#!/usr/bin/env python

# Copyright (C) 2014  Sangoma Technologies Corp.
# All Rights Reserved.
#
# Author(s):
# Leonardo Lang <lang@sangoma.com>

# Boot loader configuration rewritten by this script, and the temporary file
# the new version is written to before it is moved over the original.
GRUB_CONF = '/boot/grub/grub.conf'
GRUB_TEMP = GRUB_CONF + '.new'

# Product definitions XML; get_product_version() reads the path in its
# <product><about> element and takes the product_version attribute from the
# <product> element of that second file.
PROD_DEFS = '/usr/local/sng/conf/prod-def.xml'

# Hardware on which the serial console should be enabled (scan_lists() returns
# True for a match).  A 'dmi' entry maps a file name under
# /sys/devices/virtual/dmi/id/ to an accepted value or list of values; a 'pci'
# entry gives a vendor id and a device id (or list of ids) compared against the
# hex values in the device's sysfs attributes.
SERIAL_WHITELIST = [
    # chassis vendor strings of common virtual machine platforms
    dict(dmi=dict(chassis_vendor=['QEMU', 'Xen', 'VirtualBox'])),
    dict(pci=dict(vendor=0x8086, device=[0x0100,0x0104]))
]

# Hardware for which scan_lists() reports a blacklist match, in which case
# setup_output() leaves grub.conf untouched.  Same format as SERIAL_WHITELIST.
SERIAL_BLACKLIST = [
    dict(pci=dict(vendor=0x8086, device=[0x0bf0,0x0bf1,0x0bf2,0x0bf3,0x0bf4,0x0bf5,0x0bf6,0x0bf7])),
]

# Arguments inserted into every rewritten 'kernel' line unless the exact same
# argument is already present.
KERN_EXTRA_ARGS = [
    'panic=15',                     # if the kernel panics, make the system reboot (issue #10352)
    'crashkernel=128M@64M',         # reserve space for running kdump as the crashkernel (issue #10363)
    'xen_emul_unplug=unnecessary',  # don't unplug the xen emulated devices so the crashkernel can work on xen guests (issue #10524)
]

# Kernel console= arguments for the serial port and for the standard console.
CONSOLE_SERIAL = 'console=ttyS0,115200n8'
CONSOLE_STDOUT = 'console=tty0'

# Every console= argument this script manages.  Existing occurrences are
# stripped from kernel lines before the wanted ones are re-inserted.
ALL_CONSOLE_LIST = [ CONSOLE_SERIAL, CONSOLE_STDOUT ]

# grub 'color' directive written after the rewritten header.
CFG_COLOUR_LINE = '\ncolor black/cyan white/blue\n'

# One indented boot-entry attribute line: whitespace, key, whitespace, value,
# newline.  Groups: (key, value).
# NOTE(review): [^\t ] also matches '\n', so an attribute line with no value
# (e.g. a bare "savedefault") swallows the following line into its key.
# Confirm grub.conf never contains such lines.
REGEX_ATTRS_FORMAT = '[\t ]+([^\t ]+)[\t ]+(.*?)[\n]'

# sysfs directory with one entry per PCI device.
SYSFS_PCI = '/sys/bus/pci/devices/'

import os
import re
import sys

import subprocess
import logging
import logging.handlers

from xml.dom.minidom import parse as xmlParse

# A whole boot entry: "title <name>\n" followed by one or more attribute lines.
# findall() yields tuples whose [0] is the title and [1] the raw block of
# attribute lines (the remaining groups come from the repeated attribute part).
REGEX_ENTRY = re.compile('title (.*?)[\n](({attr})+)'.format(attr=REGEX_ATTRS_FORMAT))
# Individual attribute lines; findall() yields (key, value) pairs.
REGEX_ATTRS = re.compile('{attr}'.format(attr=REGEX_ATTRS_FORMAT))

# Header directives that are dropped when the config is rewritten (a line that
# starts with 'splashimage=' is commented out instead).
REGEX_VISUAL = re.compile('^(color|splashimage|hiddenmenu)')
# Serial/terminal header directives; always dropped from the old header.
REGEX_SERIAL = re.compile('^(serial|terminal)')
# Lines that belong to boot entries: title lines and indented lines.
REGEX_FILTER = re.compile('^(title|[\t ])')

# Titles of the form "<name> (<inner>) <rest>".  Only entries whose title
# matches are rewritten by process_grub_conf(); others are copied through.
# NOTE(review): a space is required after ')', so a title that simply ends
# with ')' does not match.  Confirm that is intended.
REGEX_RENAME = re.compile('^(.*?) [(](.*?)[)] (.*?)$')

def setup_logger():
    """Create and return the logger used throughout this script.

    The logger is named after the running script, logs at INFO level and
    sends records to syslog.  When stdout is a terminal, records are also
    emitted through a StreamHandler (which writes to stderr by default).
    """
    logger = logging.getLogger(os.path.basename(sys.argv[0]))
    try:
        # SysLogHandler() with no address sends UDP datagrams to
        # localhost:514 (the documented default), which the local syslog
        # daemon may not accept.  Prefer the local /dev/log socket.
        syslog = logging.handlers.SysLogHandler(address='/dev/log')
    except EnvironmentError:
        # No usable /dev/log: keep the previous default transport.
        syslog = logging.handlers.SysLogHandler()
    logger.addHandler(syslog)
    if os.isatty(1):
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    return logger

logger = setup_logger()

def read_data(fn):
    """Return up to the first 1024 characters of file *fn*, stripped.

    Returns None when the file cannot be opened, read or decoded.
    """
    try:
        # 'with' closes the file even if read() fails part-way.
        with open(fn) as rfd:
            return rfd.read(1024).strip()
    except (EnvironmentError, ValueError):
        # EnvironmentError: missing/unreadable file (IOError/OSError);
        # ValueError: covers UnicodeDecodeError on undecodable contents.
        return None

def pci_devices():
    """Yield the name of every entry in SYSFS_PCI (one per PCI device).

    The directory is listed when iteration starts, not when this is called.
    """
    entries = os.listdir(SYSFS_PCI)
    for entry in entries:
        yield entry

def pci_device_attr(name, attr):
    """Return the first line of sysfs attribute *attr* of PCI device *name*.

    The value is read from SYSFS_PCI/<name>/<attr> and stripped of
    surrounding whitespace.  Returns None if the file cannot be opened,
    read or decoded.
    """
    fname = os.path.join(SYSFS_PCI, name, attr)
    try:
        # 'with' closes the file even if readline() fails.
        with open(fname) as fdes:
            return fdes.readline().strip()
    except (EnvironmentError, ValueError):
        # EnvironmentError: missing/unreadable attribute;
        # ValueError: covers UnicodeDecodeError on undecodable contents.
        return None

def get_product_version():
    """Look up the installed product's version string.

    PROD_DEFS is parsed and the text of the first non-empty <about> child of
    its root <product> element is taken as the path of a second XML file.
    The ``product_version`` attribute of that file's root <product> element
    is returned.  Any failure along the way is logged as a warning and the
    function returns None.
    """
    def gettag(elm, name):
        # Direct children of *elm* that are elements named *name*; text and
        # other nodes without a tagName are skipped.
        return [child for child in elm.childNodes
                if hasattr(child, 'tagName') and child.tagName == name]

    def gettexttag(elm, name):
        # Data of the first child node of the first non-empty <name> child
        # of *elm*, or '' when no such child exists.
        populated = [tag for tag in gettag(elm, name) if tag.childNodes]
        if not populated:
            return ''
        return populated[0].childNodes[0].data

    try:
        defs_doc = xmlParse(PROD_DEFS)
        about_path = gettexttag(gettag(defs_doc, 'product')[0], 'about')
        about_doc = xmlParse(about_path)
        return gettag(about_doc, 'product')[0].getAttribute('product_version')
    except:
        failure = sys.exc_info()
        logger.warning('unable to get product version: {ea!s}({eb!s})'.format(ea=failure[0].__name__, eb=failure[1]))
    return None

def write_grub_entry(fd, title, attrdata):
    """Write one grub boot entry to the open text file *fd*.

    Emits a blank line and ``title <title>``, then one tab-indented
    ``<key> <value>`` line for each (key, value) pair in *attrdata*.
    """
    fd.write('\ntitle {name}\n'.format(name=title))
    fd.writelines('\t{key} {val}\n'.format(key=pair[0], val=pair[1])
                  for pair in attrdata)

def process_attributes(attr_list, with_serial, with_stdout):
    """Build the per-console attribute lists for one grub boot entry.

    :param attr_list: (key, value) pairs of the original entry.
    :param with_serial: build the serial-console variant.
    :param with_stdout: build the standard-console variant.
    :returns: tuple ``(opts_serial, opts_stdout)``.

    Non-kernel attributes are copied into each enabled variant.  For a
    ``kernel`` attribute, arguments equal to an entry of ALL_CONSOLE_LIST are
    removed, any KERN_EXTRA_ARGS not already present are inserted at position
    3, and a console string is then inserted at position 3:
    ``ALL_CONSOLE_LIST`` order for the stdout variant, reversed order for the
    serial variant, so that each variant's own console comes last.  A
    disabled variant receives the kernel attribute unmodified.

    NOTE(review): each enabled variant always lists both consoles, even when
    only the serial variant is requested.  Confirm that is intended.
    """
    serial_opts = []
    stdout_opts = []

    for attr in attr_list:
        if attr[0] != 'kernel':
            if with_stdout:
                stdout_opts.append(attr)
            if with_serial:
                serial_opts.append(attr)
            continue

        args = [arg for arg in attr[1].split(' ') if arg not in ALL_CONSOLE_LIST]
        for extra in KERN_EXTRA_ARGS:
            if extra not in args:
                args.insert(3, extra)

        variants = (
            (with_stdout, stdout_opts, ' '.join(ALL_CONSOLE_LIST)),
            (with_serial, serial_opts, ' '.join(reversed(ALL_CONSOLE_LIST))),
        )
        for enabled, target, consoles in variants:
            if enabled:
                # Same result as copying args and calling insert(3, consoles).
                kernel_args = args[:3] + [consoles] + args[3:]
                target.append((attr[0], ' '.join(kernel_args)))
            else:
                target.append(attr)

    return (serial_opts, stdout_opts)

def process_grub_conf(with_serial=False, with_stdout=False):
    """Rewrite GRUB_CONF so boot entries use the requested console(s).

    :param with_serial: emit a serial-console variant of each rewritten entry.
    :param with_stdout: emit a standard-console variant of each rewritten entry.

    GRUB_CONF is split into header lines and boot-entry lines (REGEX_FILTER).
    A lone '"' line is dropped.  A blank line before a header line is written
    back as a single '\\n'.  serial/terminal/color/hiddenmenu directives are
    dropped and 'splashimage=' is commented out.  CFG_COLOUR_LINE follows the
    header, plus serial/terminal setup when both consoles are requested.

    Entries whose title matches REGEX_RENAME are retitled with the product
    version (the original title is kept if the version is unknown).  They are
    written once per requested console (see process_attributes).  All other
    entries are copied through.

    The result is written to GRUB_TEMP, flushed to disk and moved over
    GRUB_CONF.  Read or processing errors are logged as warnings and leave
    GRUB_CONF unchanged.  Errors opening GRUB_CONF or GRUB_TEMP propagate.
    """
    grubfd = open(GRUB_CONF)

    entrydata, headerdata = [], []
    linebreak = True

    try:
        for line in grubfd:
            if line == '"\n': # remove anaconda garbage
                continue

            if line == '\n':
                linebreak = True
                continue

            if REGEX_FILTER.match(line):
                entrydata.append(line)
                continue

            if REGEX_SERIAL.match(line):
                continue

            if REGEX_VISUAL.match(line):
                if not line.startswith('splashimage='):
                    continue
                line = '#{ln}'.format(ln=line)

            if linebreak:
                headerdata.append('\n')
                linebreak = False

            headerdata.append(line)
    except:
        exctype, excname = sys.exc_info()[:2]
        logger.warning('unable to read grub config file: {ea!s}({eb!s})'.format(ea=exctype.__name__, eb=excname))
        return

    finally:
        grubfd.close()

    # The product version is looked up lazily (only if some entry is
    # retitled) and at most once, even when the lookup fails.
    prodversion, version_looked_up = None, False

    grubfd = open(GRUB_TEMP, 'w')

    try:
        for line in headerdata:
            grubfd.write(line)

        grubfd.write(CFG_COLOUR_LINE)

        if with_serial and with_stdout:
            grubfd.write('\nserial --unit=0 --speed=115200 --parity=no --stop=1\n')
            grubfd.write('terminal --timeout=3 serial console\n')

        for entry in REGEX_ENTRY.findall(''.join(entrydata)):
            rm_rename = REGEX_RENAME.match(entry[0])
            opts_data = REGEX_ATTRS.findall(entry[1])

            if rm_rename is None or (not with_serial and not with_stdout):
                write_grub_entry(grubfd, entry[0], opts_data)
                continue

            if not version_looked_up:
                prodversion = get_product_version()
                version_looked_up = True

            if prodversion:
                new_title = '{name} {version}'.format(name=rm_rename.group(1), version=prodversion)
            else:
                # Lookup failed (None) or the attribute was empty: keep the
                # original title instead of writing "<name> None" into the
                # boot menu.
                new_title = entry[0]

            stdout_title = '{nt}{ne}'.format(nt=new_title, ne=(' (standard console)' if with_stdout and with_serial else ''))
            serial_title = '{nt}{ne}'.format(nt=new_title, ne=(' (serial console)'   if with_stdout and with_serial else ''))

            opts_serial, opts_stdout = process_attributes(opts_data, with_serial, with_stdout)

            if with_stdout:
                write_grub_entry(grubfd, stdout_title, opts_stdout)

            if with_serial:
                write_grub_entry(grubfd, serial_title, opts_serial)

        # Push the new config to disk before it replaces the old one.  This
        # reduces the risk of a crash leaving a truncated boot config behind.
        grubfd.flush()
        os.fsync(grubfd.fileno())

    except:
        exctype, excname = sys.exc_info()[:2]
        logger.warning('unable to process grub data: {ea!s}({eb!s})'.format(ea=exctype.__name__, eb=excname))
        subprocess.call(['rm', '-f', GRUB_TEMP])
        return

    finally:
        grubfd.close()

    # A failed move used to go unnoticed; report it.
    if subprocess.call(['mv', '-f', GRUB_TEMP, GRUB_CONF]) != 0:
        logger.warning('unable to replace grub config file with {tmp}'.format(tmp=GRUB_TEMP))

def setup_serial_only():
    """Rewrite grub.conf with serial-console boot entries only (no
    standard-console variants and no grub serial terminal setup)."""
    return process_grub_conf(with_stdout=False, with_serial=True)

def setup_serial_vga():
    """Rewrite grub.conf with both a standard-console and a serial-console
    variant of each rewritten boot entry, plus grub serial terminal setup."""
    return process_grub_conf(with_stdout=True, with_serial=True)

def find_video_device():
    """Return True if any PCI device reports class 0x030000, else False.

    setup_output() treats a True result as "we have VGA".  Devices whose
    ``class`` attribute cannot be read or parsed are skipped instead of
    aborting the whole scan.
    """
    for name in pci_devices():
        try:
            devclass = int(pci_device_attr(name, 'class'), 16)
        except (TypeError, ValueError):
            # pci_device_attr() returned None (unreadable) or non-hex text.
            continue
        if devclass == 0x30000:
            return True
    return False

def match_pci_list(name, data):
    """Return True if PCI device *name* matches any 'pci' entry in *data*.

    Each ``{'pci': {'vendor': V, 'device': D}}`` entry matches when the
    device's sysfs vendor id equals V and its device id equals D (or is a
    member of D when D is a list, tuple or set).  Entries without a 'pci' key
    are ignored.  Attributes that cannot be read or parsed never match, so an
    odd device no longer aborts the scan.
    """
    def read_hex(attr):
        # sysfs ids are hex strings such as "0x8086"; None if unreadable.
        try:
            return int(pci_device_attr(name, attr), 16)
        except (TypeError, ValueError):
            return None

    for item in data:
        if 'pci' not in item:
            continue
        spec = item['pci']

        if read_hex('vendor') != spec['vendor']:
            continue

        gotdev = read_hex('device')
        wanted = spec['device']

        if isinstance(wanted, (list, tuple, set, frozenset)):
            if gotdev in wanted:
                return True
        elif gotdev == wanted:
            return True

    return False

def match_dmi_list(data):
    """Return True if any 'dmi' entry in *data* matches this machine.

    For each ``{'dmi': {field: expected}}`` entry, every field is read from
    /sys/devices/virtual/dmi/id/<field> via read_data().  The value is tested
    for membership in *expected* when that is a list, or for equality
    otherwise.  The first matching field wins.  Entries without a 'dmi' key
    are ignored.
    """
    dmi_specs = (entry['dmi'] for entry in data if 'dmi' in entry)
    for fields in dmi_specs:
        for field, expected in fields.items():
            actual = read_data('/sys/devices/virtual/dmi/id/{k}'.format(k=field))
            if isinstance(expected, list):
                matched = actual in expected
            else:
                matched = expected == actual
            if matched:
                return True
    return False

def scan_lists():
    """Classify this machine against SERIAL_BLACKLIST and SERIAL_WHITELIST.

    :returns: False if blacklisted, True if whitelisted, None if in neither.

    DMI entries are checked first (blacklist, then whitelist).  For PCI
    entries, every device is checked against the blacklist before any device
    is checked against the whitelist.  A blacklisted device therefore wins
    regardless of the order in which sysfs lists devices.
    """
    if match_dmi_list(SERIAL_BLACKLIST):
        return False

    if match_dmi_list(SERIAL_WHITELIST):
        return True

    # List once so both passes see the same set of devices.
    devices = list(pci_devices())

    if any(match_pci_list(name, SERIAL_BLACKLIST) for name in devices):
        return False

    if any(match_pci_list(name, SERIAL_WHITELIST) for name in devices):
        return True

    return None

def setup_output():
    """Choose the grub console setup from the detected hardware and apply it.

    Outcomes, in order of precedence:
      * blacklisted (scan_lists() is False): config left untouched;
      * whitelisted and a video device present: serial + standard console;
      * no video device (whitelisted or unlisted): serial console only;
      * unlisted and a video device present: config left untouched.
    """
    with_video = find_video_device()
    what_match = scan_lists()

    # scan_lists() is tri-state (True / False / None): compare by identity.
    if what_match is False:
        logger.info('found in blacklist, not touching config')
        return

    if what_match is True and with_video:
        logger.info('found in whitelist and we have VGA, enabling serial console')
        setup_serial_vga()
        return

    if not with_video:
        # Reached for whitelisted as well as unlisted machines; say which.
        logger.info('{e} and we do not have VGA, enabling serial console only'.format(
            e='found in whitelist' if what_match is True else 'not found in any list'))
        setup_serial_only()
        return

    # Only an unlisted machine with video gets here.
    logger.info('not found in lists and we do have VGA, not touching config')

def main():
    """Entry point: configure boot console output and log any failure.

    The bare ``except`` reports every exception (including SystemExit and
    KeyboardInterrupt) as a warning instead of letting it propagate.
    """
    try:
        setup_output()
    except:
        failure = sys.exc_info()
        logger.warning('unable to update grub config: {ea!s}({eb!s})'.format(ea=failure[0].__name__, eb=failure[1]))

# Runs on every execution of this file, including when it is imported as a
# module (there is no ``if __name__ == '__main__'`` guard).
main()
