#!/usr/bin/env python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2014  Sangoma Technologies Corp.
* All Rights Reserved.
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import os
import re
import sys
import shutil
import logging
import logging.handlers
import fcntl
import traceback
from optparse import OptionParser
import logging
import logging.handlers
import subprocess
import signal
import time

# Path opened in append mode and locked with fcntl.lockf() around every
# configuration rewrite in SnortConfUpdater.do_process_config(). It is the
# script's own path (sys.argv[0]), so invocations through the same path
# share one lock.
snort_lock_file = sys.argv[0]

class SnortConfUpdater:
    """Rewrite snort configuration files and signal running snort instances.

    The constructor parses the command line into ``self.options``. The
    ``do_*`` methods rewrite configuration files in place and return True
    when something changed. The ``check_*`` methods inspect the
    ``/var/run/snort_*.pid`` files. ``reload_snort`` sends SIGHUP to every
    instance from a forked child.
    """

    # File locked (fcntl.lockf) around every configuration rewrite so that
    # concurrent invocations do not interleave. None means: use the
    # module-level ``snort_lock_file``.
    lock_file = None

    def __init__(self):
        """Set up the config file table, logging and command-line options.

        Log records go to /var/log/snortconfupdater.log (rotated at 10 MiB,
        4 backups). The parsed options end up in ``self.options``.
        """

        # config id -> path of the file holding that part of the configuration
        self.config_file = {
            'main':    '/etc/snort.conf',
            'http':    '/etc/snort/http.conf',
            'sip':     '/etc/snort/sip.conf',
            'ssh':     '/etc/snort/ssh.conf',
            'http-pp': '/etc/snort/http-preprocess.conf',
            'ssl-pp':  '/etc/snort/ssl-preprocess.conf',
            'sip-pp':  '/etc/snort/sip-preprocess.conf',
            'ssh-pp':  '/etc/snort/ssh-preprocess.conf',
            'network': '/etc/snort/network.conf'
        }

        self._script_name = 'snortconfupdater'

        self._logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        self._logger = logging.getLogger(self._script_name)
        self._logger.setLevel(logging.INFO)

        # getLogger() returns a shared logger object; do not stack duplicate
        # file handlers if the class is instantiated more than once.
        if not self._logger.handlers:
            formatter = logging.Formatter(self._logformat)
            log_handler = logging.handlers.RotatingFileHandler('/var/log/' + self._script_name + '.log', maxBytes=10*1024*1024, backupCount=4)
            log_handler.setFormatter(formatter)
            self._logger.addHandler(log_handler)

        # custom SnortConfUpdater options should be added here via parser.add_option()
        self.parser = OptionParser()

        self.parser.add_option("", "--init", dest="init", action='store_true',
                        help="Initialize snort configuration")
        self.parser.add_option("", "--debug", dest="debug", action='store_true',
                        help="Enable debug mode")
        self.parser.add_option("", "--no-reload", dest="skip_reload", action='store_true',
                        help="Don't reload configuration")
        self.parser.add_option("", "--reload", dest="force_reload", action='store_true',
                        help="Reload configuration")

        self.parser.add_option("", "--running", dest="running", action='store_true',
                        help="Check if any snort instance is running")
        self.parser.add_option("", "--on-interface", dest="iface",
                        help="Check if snort is running on IFACE")

        self.parser.add_option("", "--home-network", dest="home_network",
                        help="Update snort home network IPs")
        self.parser.add_option("", "--ssh-ports", dest="ssh_port_list",
                        help="Update snort SSH ports")
        self.parser.add_option("", "--http-ports", dest="http_port_list",
                        help="Update snort HTTP ports")
        self.parser.add_option("", "--ssl-ports", dest="ssl_port_list",
                        help="Update snort HTTPS ports")
        self.parser.add_option("", "--ports", dest="sip_port_list",
                        help="Update snort SIP ports")

        (self.options, _args) = self.parser.parse_args()

    def do_process_config(self, cfgname, replace_vars, replace_conf):
        """Rewrite config file *cfgname*, and every file it includes, in place.

        Each file is read line by line and written to a ``<name>.new``
        temporary file, which is then renamed over the original (keeping its
        permission bits). ``var`` lines are recorded so that ``$NAME``
        references in ``include`` paths can be expanded. Included files are
        processed recursively with the same rules.

        :param cfgname: key into ``self.config_file``.
        :param replace_vars: ``{'ipvar'|'portvar': {name: action}}`` for
            variable lines. An 'enable' action also uncomments a matching
            ``# <var line>``.
        :param replace_conf: ``{'config'|'preprocessor': {name: action}}``
            for statement lines. The action also applies to every backslash
            continuation line of the statement.

        An action is a 2-tuple whose first element selects the operation:

        * ``('disable', _)`` comments the line out.
        * ``('enable', _)`` uncomments a commented-out variable line.
        * ``('replace', {key: text})`` replaces the variable value (key
          ``'value'``) or the first match of ``regrepl[key]``.
        * ``('move', cfgid)`` moves the line into ``self.config_file[cfgid]``
          and leaves an ``include`` of that file behind.

        :returns: True if any processed file was modified.
        :raises KeyError: unknown *cfgname*.
        :raises IOError/OSError: a file cannot be read, written or renamed.
        """
        cfgfile = self.config_file[cfgname]

        regcomm = re.compile('^((#.*\n)|([ ]*)\n)$')
        regincl = re.compile('^(include|dynamic(preprocessor|engine|detection)|output|alert) (.+)\n$')
        regvars = re.compile('^(var|ipvar|portvar)[ \t]+([A-Za-z0-9_]+)[ \t]+(.+)\n$')
        regconf = re.compile('^(config|preprocessor)[ \t]+([A-Za-z0-9_]+)(:[ \t]+([^\\\\]*))?([\\\\][ ]*)?\n$')
        regnext = re.compile('^([^\\\\]+)([\\\\][ ]*)?\n$')

        regrepl = {}
        regrepl['value'] = re.compile('(\\[.+?\\]|any)')
        regrepl['ports'] = re.compile('(server_)?ports [{] ([0-9 ]+) [}]')

        varstable = {}   # 'var' name -> expanded value, in file order
        filecache = {}   # config id -> open output file receiving moved lines

        # Backslash-continuation state shared by the nested helpers. A dict
        # is used because Python 2 has no 'nonlocal'.
        cont = {'line': False, 'data': None}

        def substitute_vars(strtext):
            """Expand every ``$NAME`` that refers to a recorded 'var'."""
            if not varstable:
                self._logger.debug('variable table empty, skipping..')
                return strtext

            # The lookahead stops $FOO from matching the head of $FOO_BAR.
            restr = '[$](' + '|'.join(varstable) + ')(?![A-Za-z0-9_])'
            self._logger.debug('substituting, regexp: ' + restr)
            rvar = re.compile(restr)
            posn = 0
            while True:
                mvar = rvar.search(strtext, posn)
                if mvar is None:
                    break
                varvalue = varstable.get(mvar.group(1))
                if not varvalue:
                    break
                self._logger.debug('found "' + mvar.group(1) + '", replacing by "' + varvalue + '"...')
                strtext = strtext[:mvar.start()] + varvalue + strtext[mvar.end():]
                # Resume right after the inserted text. Offsets refer to the
                # NEW string, whose length changed.
                posn = mvar.start() + len(varvalue)
            return strtext

        def handle_data(inpname, outfile, line, stripline, data, value=None):
            """Apply action *data* to *line* and write the result.

            The result normally goes to *outfile*; a 'move' action writes it
            to the move target instead. *value* is the (start, end) span of
            the variable value inside *line*, if any. Returns True if the
            configuration changed.
            """
            action = data[0]

            if action == 'disable':
                self._logger.info('commenting line: ' + stripline)
                outfile.write('# ' + line)
                return True

            if action == 'replace':
                self._logger.debug('checking replace at "' + stripline + '"...')
                try:
                    replaced = False
                    for rkey in data[1].keys():
                        if rkey == 'value':
                            if value:
                                if line[value[0]:value[1]] != data[1][rkey]:
                                    self._logger.info('replacing value "' + line[value[0]:value[1]] + '" with "' + data[1][rkey] + '" at: ' + stripline)
                                    newline = line[:value[0]] + data[1][rkey] + line[value[1]:]
                                    outfile.write(newline)
                                    replaced = True
                                else:
                                    self._logger.info('skipping replacement of "' + data[1][rkey] + '", values are the same.')
                                break
                            else:
                                self._logger.warning('got "value" replace but empty value, skipping...')
                        else:
                            self._logger.debug('matching "' + rkey + '" expression...')
                            mrepl = regrepl[rkey].search(line)
                            if mrepl is not None:
                                if line[mrepl.start():mrepl.end()] != data[1][rkey]:
                                    self._logger.info('replacing match "' + line[mrepl.start():mrepl.end()] + '" with "' + data[1][rkey] + '" at: ' + stripline)
                                    newline = line[:mrepl.start()] + data[1][rkey] + line[mrepl.end():]
                                    outfile.write(newline)
                                    replaced = True
                                else:
                                    self._logger.info('skipping replacement of match "' + data[1][rkey] + '", values are the same.')
                                break
                    if not replaced:
                        outfile.write(line)
                    return replaced

                except Exception as e:
                    # e.g. an unknown regrepl key: keep the line untouched
                    self._logger.warning('replace failed at "' + stripline + '": ' + str(e))
                    outfile.write(line)
                    return False

            if action == 'move':
                target = data[1]
                fname = self.config_file.get(target)
                if fname is None:
                    self._logger.error('no config file id "' + target + '" found, skipping..')
                    outfile.write(line)
                    return False
                if fname == inpname:
                    self._logger.debug('line "' + stripline + '" already moved to "' + target + '"...')
                    outfile.write(line)
                    return False

                self._logger.info('moving line "' + stripline + '" to "' + target + '" config...')
                outmove = filecache.get(target)
                if outmove is None:
                    self._logger.debug('opening file "' + fname + '" for moving data...')

                    # The target is truncated on the first moved line.
                    outmove = open(fname, 'w')
                    outmove.write('#### Auto-generated from ' + sys.argv[0] + ' ####\n\n')

                    # Leave an include behind so snort still reads the line.
                    outfile.write('\n#### Modified by ' + sys.argv[0] + ' ####\n')
                    outfile.write('include ' + fname + '\n\n')

                    filecache[target] = outmove

                outmove.write(line)
                return True

            if action == 'enable':
                self._logger.debug('line already enabled...')
                outfile.write(line)
                return False

            self._logger.error('unknown action "' + action + '" at "' + stripline + '", skipping...')
            outfile.write(line)
            return False

        def process_file(inpname, inpfile, outfile):
            """Copy *inpfile* to *outfile*, applying the rules; True if changed."""
            self._logger.info('processing file ' + inpname + '...')

            retfinal = False

            for line in inpfile:
                stripline = line.rstrip('\n')

                if cont['line']:
                    self._logger.debug('parsing multi-line at "' + stripline + '"...')
                    mnext = regnext.match(line)
                    if mnext is not None:
                        if cont['data'] is not None:
                            ret = handle_data(inpname, outfile, line, stripline, cont['data'])
                            retfinal = ret or retfinal
                        else:
                            outfile.write(line)

                        if mnext.group(2) is None:
                            self._logger.debug('exiting multi-line mode...')
                            cont['line'] = False
                            cont['data'] = None
                        continue

                    # Not a continuation after all: leave multi-line mode and
                    # parse the line normally below. Every path below writes
                    # it exactly once, so it must not be written here.
                    self._logger.warning('multi-line parsing error at "' + stripline + '", not processing.')
                    cont['line'] = False
                    cont['data'] = None

                if regcomm.match(line):
                    data = None
                    if line.startswith('# '):
                        mcvars = regvars.match(line[2:])
                        if mcvars is not None:
                            data = replace_vars.get(mcvars.group(1), {}).get(mcvars.group(2))

                    if data is not None and data[0] == 'enable':
                        self._logger.info('enabling line: ' + line[2:].rstrip('\n'))
                        outfile.write(line[2:])
                        retfinal = True
                    else:
                        outfile.write(line)
                    continue

                mincl = regincl.match(line)
                if mincl is not None:
                    if mincl.group(1) == 'include':
                        self._logger.debug('requesting include of "' + mincl.group(3) + '"...')
                        incname = substitute_vars(mincl.group(3))
                        self._logger.debug('scanning included "' + incname + '"...')
                        # a change inside an included file is a change too
                        retfinal = replace_file(incname) or retfinal

                    outfile.write(line)
                    continue

                mvars = regvars.match(line)
                if mvars is not None:
                    self._logger.debug('got vars line: ' + stripline)
                    if mvars.group(1) == 'var':
                        varstable[mvars.group(2)] = substitute_vars(mvars.group(3))
                        outfile.write(line)
                    else:
                        data = replace_vars.get(mvars.group(1), {}).get(mvars.group(2))
                        if data is not None:
                            ret = handle_data(inpname, outfile, line, stripline, data, (mvars.start(3), mvars.end(3)))
                            retfinal = ret or retfinal
                        else:
                            outfile.write(line)
                    continue

                mconf = regconf.match(line)
                if mconf is not None:
                    self._logger.debug('got conf line: ' + stripline)
                    data = replace_conf.get(mconf.group(1), {}).get(mconf.group(2))

                    if data is not None:
                        ret = handle_data(inpname, outfile, line, stripline, data)
                        retfinal = ret or retfinal
                    else:
                        outfile.write(line)

                    if mconf.group(5) is not None:
                        # Trailing backslash: the next lines continue this
                        # statement and get the same action.
                        cont['line'] = True
                        cont['data'] = data
                    continue

                self._logger.warning('parsing error at "' + stripline + '", not processing.')
                outfile.write(line)

            # A continuation never spans files.
            cont['line'] = False
            cont['data'] = None
            return retfinal

        def replace_file(name):
            """Rewrite *name* via ``<name>.new``; True if anything changed."""
            newname = name + '.new'
            done = False
            try:
                with open(name, 'r') as inpfile:
                    with open(newname, 'w') as outfile:
                        ret = process_file(name, inpfile, outfile)
                # keep the original file's permission bits on the rewrite
                shutil.copymode(name, newname)
                os.rename(newname, name)
                done = True
            finally:
                # never leave a half-written temporary behind
                if not done and os.path.exists(newname):
                    os.remove(newname)
            return ret

        lock_path = self.lock_file if self.lock_file is not None else snort_lock_file
        snortfd = open(lock_path, 'a')
        try:
            fcntl.lockf(snortfd, fcntl.LOCK_EX)
            try:
                ret = replace_file(cfgfile)
            finally:
                # Close (and so flush) the move targets while the lock is
                # still held, so no other invocation sees them half-written.
                for outmove in filecache.values():
                    outmove.close()
                filecache.clear()
                fcntl.lockf(snortfd, fcntl.LOCK_UN)
        finally:
            snortfd.close()

        return ret

    def do_init_update(self):
        """Reorganize the stock snort.conf into per-protocol include files.

        Unused servers, ports and preprocessors are disabled. The HOME_NET,
        HTTP, SSH and SIP definitions are moved to their own files and then
        reset to this script's defaults. Returns True if any file changed.
        """

        replace_vars = {
            'ipvar': {
                'TELNET_SERVERS': ('disable', None),
                'FTP_SERVERS': ('disable', None),
                'AIM_SERVERS': ('disable', None),

                'HOME_NET': ('move', 'network')

            },
            'portvar': {
                'ORACLE_PORTS': ('disable', None),
                'FTP_PORTS': ('disable', None),
                'FILE_DATA_PORTS': ('enable', None),

                'HTTP_PORTS': ('move', 'http'),
                'SSH_PORTS':  ('move', 'ssh'),
                'SIP_PORTS':  ('move', 'sip')
            }
        }

        replace_conf = {
            'preprocessor': {
                'ftp_telnet': ('disable', None),
                'ftp_telnet_protocol': ('disable', None),
                'rpc_decode': ('disable', None),
                'dcerpc2': ('disable', None),
                'dcerpc2_server': ('disable', None),
                'imap': ('disable',  None),
                'pop': ('disable', None),
                'smtp': ('disable', None),
                'http_inspect':        ('move', 'http-pp'),
                'http_inspect_server': ('move', 'http-pp'),
                'ssl': ('move', 'ssl-pp'),
                'ssh': ('move', 'ssh-pp'),
                'sip': ('move', 'sip-pp')
            }
        }

        # process main file
        ret = self.do_process_config('main', replace_vars, replace_conf)

        # Now set the proper values in the included files. Every call must
        # run, and any change it makes counts as a modification.
        ret = self.do_http_port_update([]) or ret
        ret = self.do_sshd_port_update([]) or ret
        ret = self.do_sip_port_update([]) or ret
        ret = self.do_home_net_update([]) or ret
        # NOTE(review): 'ssl' is moved to ssl-pp above but do_ssl_port_update()
        # is not called here, so its ports keep their stock values - confirm.

        return ret

    @staticmethod
    def _clean_list(items, default):
        """Return *items* stripped, without blanks or repeats, order kept.

        Returns a copy of *default* when nothing is left.
        """
        seen = set()
        result = []
        for item in items:
            item = item.strip()
            if item and item not in seen:
                seen.add(item)
                result.append(item)
        return result or list(default)

    def do_sip_port_update(self, origportlist):
        """Set SIP_PORTS and the sip preprocessor ports; True if changed.

        An empty list means the default 5060,5061.
        """
        portlist = self._clean_list(origportlist, ['5060', '5061'])

        sipvarstr = '[' + ','.join(portlist) + ']'
        sipprestr = '{ ' + ' '.join(portlist) + ' }'

        replace_vars = { 'portvar': { 'SIP_PORTS':  ('replace', { 'value': sipvarstr }) } }
        replace_conf = { 'preprocessor': { 'sip': ('replace', { 'ports': 'ports ' + sipprestr}) } }

        # Both files are always processed: the preprocessor file can be out
        # of date even when the variable already has the requested value.
        var_changed = self.do_process_config('sip', replace_vars, {})
        pp_changed = self.do_process_config('sip-pp', {}, replace_conf)
        return var_changed or pp_changed

    def do_http_port_update(self, origportlist):
        """Set HTTP_PORTS and the http_inspect_server ports; True if changed.

        An empty list means the default 80,81.
        """
        portlist = self._clean_list(origportlist, ['80', '81'])

        httpvarstr = '[' + ','.join(portlist) + ']'
        httpprestr = '{ ' + ' '.join(portlist) + ' }'

        replace_vars = { 'portvar': { 'HTTP_PORTS':  ('replace', { 'value': httpvarstr }) } }

        replace_conf = {
            'preprocessor': {
                'http_inspect_server': ('replace', { 'ports': 'ports ' + httpprestr }),
            }
        }

        var_changed = self.do_process_config('http', replace_vars, {})
        pp_changed = self.do_process_config('http-pp', {}, replace_conf)
        return var_changed or pp_changed

    def do_ssl_port_update(self, origportlist):
        """Set the ssl preprocessor ports; True if changed.

        An empty list means the default 443.
        """
        portlist = self._clean_list(origportlist, ['443'])

        sslprestr = '{ ' + ' '.join(portlist) + ' }'

        replace_conf = {
            'preprocessor': {
                'ssl': ('replace', { 'ports': 'ports ' + sslprestr }),
            }
        }

        return self.do_process_config('ssl-pp', {}, replace_conf)

    def do_sshd_port_update(self, origportlist):
        """Set SSH_PORTS and the ssh preprocessor server_ports; True if changed.

        An empty list means the default 22.
        """
        portlist = self._clean_list(origportlist, ['22'])

        sshvarstr = '[' + ','.join(portlist) + ']'
        sshprestr = '{ ' + ' '.join(portlist) + ' }'

        replace_vars = { 'portvar': { 'SSH_PORTS':  ('replace', { 'value': sshvarstr }) } }
        replace_conf = { 'preprocessor': { 'ssh': ('replace', { 'ports': 'server_ports ' + sshprestr }) } }

        var_changed = self.do_process_config('ssh', replace_vars, {})
        pp_changed = self.do_process_config('ssh-pp', {}, replace_conf)
        return var_changed or pp_changed

    def do_home_net_update(self, orignetworklist):
        """Set HOME_NET; True if changed.

        An empty list means the default 127.0.0.1/32.
        """
        netlist = self._clean_list(orignetworklist, ['127.0.0.1/32'])

        netstr = '[' + ','.join(netlist) + ']'

        replace_vars = { 'ipvar': { 'HOME_NET':  ('replace', { 'value': netstr }) } }

        return self.do_process_config('network', replace_vars, {})

    def check_snort_running(self):
        """Return True if a /var/run/snort_*.pid file names a live process."""
        pidpath = '/var/run/'
        try:
            pidlist = os.listdir(pidpath)
        except Exception as e:
            self._logger.error('unable to list pidfiles: %s' % str(e))
            return False

        for fn in pidlist:
            if not fn.startswith('snort_') or not fn.endswith('.pid'):
                continue
            try:
                with open(os.path.join(pidpath, fn), 'r') as fh:
                    pid = int(fh.read().strip('\n '))
                # pid <= 0 would address process groups, not one process
                if pid > 0:
                    os.kill(pid, 0)
                    return True
            except Exception:
                # Stale, unreadable or malformed pidfile: keep scanning.
                continue

        return False

    def check_snort_iface(self, iface):
        """Return True if /var/run/snort_<iface>.pid names a live process."""
        pidpath = '/var/run/snort_' + iface + '.pid'
        try:
            with open(pidpath, 'r') as fh:
                data = fh.read().strip('\n ')
        except Exception as e:
            self._logger.error('unable to find snort pid: %s' % str(e))
            return False

        try:
            pid = int(data)
            if pid <= 0:
                return False
            os.kill(pid, 0)
            return True
        except Exception:
            return False

    def _control_file_sync(self, ctlpath, truncate_at=None):
        """Return the size of *ctlpath*, measured under an exclusive lock.

        If *truncate_at* is given and equals that size, the file is also
        truncated to zero before the lock is released.
        """
        fh = open(ctlpath, 'a')
        try:
            fcntl.lockf(fh, fcntl.LOCK_EX)
            try:
                fh.seek(0, os.SEEK_END)
                size = fh.tell()
                if truncate_at is not None and size == truncate_at:
                    fh.truncate(0)
            finally:
                fcntl.lockf(fh, fcntl.LOCK_UN)
        finally:
            fh.close()
        return size

    def execute_reload_stort(self):
        """Coalesce concurrent reload requests, then SIGHUP every snort.

        Every request appends one byte to /var/run/snort-restart-control.
        The request that made the file one byte long is the leader. It
        polls every 2 seconds until the file size stops changing, truncates
        the file and signals each ``/var/run/snort_*.pid`` process.

        Other requests watch the file instead. They give up as soon as it
        shrinks, since only a leader truncates it. If it stays unchanged for
        three consecutive polls, they assume the leader is gone and reload
        themselves. (The misspelled method name is kept for compatibility.)

        :returns: True if the SIGHUP phase ran; False if another request
            performed the reload or synchronization failed.
        """
        pidpath = '/var/run/'
        ctlpath = os.path.join(pidpath, 'snort-restart-control')

        pos = 0

        try:
            fh = open(ctlpath, 'a')
            try:
                fcntl.lockf(fh, fcntl.LOCK_EX)
                try:
                    fh.write('x')
                    # Flush before tell()/unlock so the byte reaches the file
                    # before any other request can measure it.
                    fh.flush()
                    pos = fh.tell()
                finally:
                    fcntl.lockf(fh, fcntl.LOCK_UN)
            finally:
                fh.close()

            if pos != 1:
                self._logger.info('(%d) multiple reloads, delegating to first...' % pos)

                pos1 = pos
                tout = 3
                tnow = tout

                while tnow != 0:
                    time.sleep(2)

                    try:
                        pos2 = self._control_file_sync(ctlpath)
                    except Exception as e:
                        self._logger.error('(%d) reload watchdog exiting, caught exception: %s', pos, str(e))
                        return False

                    if pos2 < pos1:
                        # the file only shrinks when a leader truncates it
                        self._logger.info('(%d) reload executed, bailing out...', pos)
                        return False

                    if pos2 == pos1:
                        tnow -= 1
                    else:
                        # more requests arrived: restart the timeout from here
                        tnow = tout
                        pos1 = pos2

                self._logger.error('(%d) timeout on reload, assuming reload process!' % pos)

            else:
                self._logger.info('(%d) waiting for more reloads...' % pos)
                time.sleep(2)

            # Wait until no new request arrives between two polls, then clear
            # the control file (atomically with the final size check).
            pos1 = pos
            while True:
                pos2 = self._control_file_sync(ctlpath, truncate_at=pos1)
                if pos2 == pos1:
                    break

                self._logger.info('(%d) got more requests, still waiting...' % (pos))

                pos1 = pos2
                time.sleep(2)

        except Exception as e:
            self._logger.error('(%d) unable to synchronize reload: %s' % (pos, str(e)))
            return False

        self._logger.info('(%d) proceeding with reload...' % pos)

        try:
            pidlist = os.listdir(pidpath)
        except Exception as e:
            self._logger.error('(%d) unable to list pidfiles: %s' % (pos, str(e)))
            return False

        for fn in pidlist:
            if not fn.startswith('snort_') or not fn.endswith('.pid'):
                continue

            data = ''
            try:
                with open(os.path.join(pidpath, fn), 'r') as fh:
                    data = fh.read().strip('\n ')
                pid = int(data)
                # pid 0 / negative pids would signal whole process groups
                # (or, for -1, every process) - never do that.
                if pid <= 0:
                    raise ValueError('invalid pid %d' % pid)
                os.kill(pid, signal.SIGHUP)
            except Exception as e:
                self._logger.warning('(%d) unable to send SIGHUP to %s: %s' % (pos, data, str(e)))

        return True

    def reload_snort(self):
        """Fork a child that runs execute_reload_stort() and then exits.

        The parent returns True immediately and does not wait for the child.
        The child never returns; it leaves through sys.exit(0).
        """

        pid = os.fork()

        if pid != 0:
            return True

        self.execute_reload_stort()
        sys.exit(0)

## main() ##

def main():
    """Command-line entry point; always terminates through sys.exit().

    Exit status: 0 on success (or "modified" with --no-reload), 123 for a
    negative answer (snort not running / nothing modified with --no-reload),
    1 on usage errors or unexpected exceptions.
    """
    snortconf = None

    try:
        snortconf = SnortConfUpdater()
        options = snortconf.options

        if options.debug is not None:
            snortconf._logger.setLevel(logging.DEBUG)

        if options.running is not None:
            sys.exit(0 if snortconf.check_snort_running() else 123)

        if options.iface is not None:
            sys.exit(0 if snortconf.check_snort_iface(options.iface) else 123)

        anything = False
        modified = False

        if options.init is not None:
            modified = snortconf.do_init_update() or modified
            anything = True

        # (comma-separated option value, updater), applied in this order
        updates = [
            (options.sip_port_list, snortconf.do_sip_port_update),
            (options.http_port_list, snortconf.do_http_port_update),
            (options.ssl_port_list, snortconf.do_ssl_port_update),
            (options.ssh_port_list, snortconf.do_sshd_port_update),
            (options.home_network, snortconf.do_home_net_update),
        ]
        for value, update in updates:
            if value is not None:
                # Accumulate: a later "unchanged" result must not hide an
                # earlier change (which would skip the reload).
                modified = update(value.split(',')) or modified
                anything = True

        if options.force_reload is not None:
            snortconf.reload_snort()
            modified = False   # already reloading; don't reload twice
            anything = True

        if not anything:
            snortconf.parser.print_help()
            sys.exit(1)

        if options.skip_reload is not None:
            sys.exit(0 if modified else 123)

        if modified:
            snortconf._logger.info('configuration modified, reloading snort...')
            snortconf.reload_snort()

    except SystemExit:
        # Deliberate exits (including the forked reload child) pass through.
        raise

    except Exception:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        if snortconf is not None:
            snortconf._logger.error("Unexpected exception: %s/%s\n" % (exc_type, exc_value))
            snortconf._logger.error('-'*60 + "\n")
            snortconf._logger.error(traceback.format_exc())
            snortconf._logger.error('-'*60 + "\n")
        else:
            sys.stderr.write("Unexpected exception %s/%s, aborting.\n" % (exc_type, exc_value))
        sys.exit(1)


if __name__ == '__main__':
    main()

