#!/usr/bin/env python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2014  Sangoma Technologies Corp.
* All Rights Reserved.
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import os
import re
import sys
import shutil
import logging
import logging.handlers
import fcntl
import traceback
from optparse import OptionParser
import logging
import logging.handlers
import subprocess
import signal

# Path of this script as it was invoked.  The script file itself doubles as the
# lock file: SnortConfUpdater.do_process_config() opens it in append mode (so
# write access to it is needed) and holds an exclusive fcntl.lockf() lock on it
# while rewriting configuration files, serializing concurrent runs.
snort_lock_file = sys.argv[0]

class SnortConfUpdater:
    """Edit the snort configuration in place and reload snort.

    ``config_file`` maps short ids to the main configuration file and to the
    per-service files that ``--init`` splits out of it.  Each ``do_*_update``
    method builds action tables and hands them to ``do_process_config``,
    which rewrites one file plus every file it includes.  Command line
    options are parsed by the constructor into ``self.options``.
    """

    def __init__(self):
        """Set up the config file table, file logging and option parsing.

        Logs go to /var/log/snortconfupdater.log (rotated at 10 MiB, four
        backups).  ``sys.argv`` is parsed immediately into ``self.options``.
        """
        # Config ids used by the 'move' action and by the do_*_update methods.
        self.config_file = {
            'main': '/etc/snort.conf',
            'http': '/etc/snort/http.conf',
            'sip': '/etc/snort/sip.conf',
            'ssh': '/etc/snort/ssh.conf',
            'network': '/etc/snort/network.conf',
        }

        self._script_name = 'snortconfupdater'

        self._logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        self._logger = logging.getLogger(self._script_name)
        self._logger.setLevel(logging.INFO)

        formatter = logging.Formatter(self._logformat)

        log_handler = logging.handlers.RotatingFileHandler(
            '/var/log/' + self._script_name + '.log',
            maxBytes=10 * 1024 * 1024, backupCount=4)
        log_handler.setFormatter(formatter)
        self._logger.addHandler(log_handler)

        # custom SnortConfUpdater options should be added here via parser.add_option()
        self.parser = OptionParser()

        self.parser.add_option("", "--init", dest="init", action='store_true',
                               help="Initialize snort configuration")
        self.parser.add_option("", "--debug", dest="debug", action='store_true',
                               help="Enable debug mode")

        self.parser.add_option("", "--home-network", dest="home_network",
                               help="Update snort home network IPs")
        self.parser.add_option("", "--ssh-ports", dest="ssh_port_list",
                               help="Update snort SSH ports")
        self.parser.add_option("", "--http-ports", dest="http_port_list",
                               help="Update snort HTTP ports")
        self.parser.add_option("", "--ports", dest="sip_port_list",
                               help="Update snort SIP ports")

        self.options = self.parser.parse_args()[0]

    @staticmethod
    def _unique_items(items):
        """Return *items* stripped, without empty entries or duplicates.

        The order of first appearance is kept, so the generated snort lists
        are deterministic and follow the order given on the command line.
        """
        seen = set()
        result = []
        for item in items:
            item = item.strip()
            if item and item not in seen:
                seen.add(item)
                result.append(item)
        return result

    def _substitute_vars(self, strtext, varstable):
        """Expand ``$NAME`` references in *strtext* using *varstable*.

        Only names present in *varstable* are expanded; other ``$`` references
        are left untouched.  Inserted values are not scanned again.
        """
        if not varstable:
            self._logger.debug('variable table empty, skipping..')
            return strtext

        # Longest names first, and reject a match followed by another name
        # character, so "$RULE_PATH" is never expanded as "$RULE" + "_PATH".
        names = sorted(varstable, key=len, reverse=True)
        restr = '[$](' + '|'.join([re.escape(n) for n in names]) + ')(?![A-Za-z0-9_])'
        self._logger.debug('substituting, regexp: ' + restr)
        rvar = re.compile(restr)

        posn = 0
        while True:
            mvar = rvar.search(strtext, posn)
            if mvar is None:
                break
            varvalue = varstable.get(mvar.group(1))
            if not varvalue:
                break
            self._logger.debug('found "' + mvar.group(1) + '", replacing by "' + varvalue + '"...')
            strtext = strtext[:mvar.start()] + varvalue + strtext[mvar.end():]
            # Resume right after the inserted text: offsets of the match refer
            # to the string before substitution.
            posn = mvar.start() + len(varvalue)
        return strtext

    def do_process_config(self, cfgname, replace_vars, replace_conf):
        """Apply action tables to config file *cfgname* and everything it includes.

        :param cfgname: key into ``self.config_file`` naming the file to rewrite.
        :param replace_vars: ``{'ipvar'|'portvar': {NAME: action}}`` applied to
            variable definitions.
        :param replace_conf: ``{'config'|'preprocessor': {NAME: action}}``
            applied to directives, including their backslash-continued lines.

        An action is ``('disable', None)`` (comment the line out),
        ``('replace', {key: text})`` (key ``'value'`` replaces the variable
        value, ``'ports'`` replaces a ``ports { ... }`` clause) or
        ``('move', cfgid)`` (move the line into config file *cfgid*, which is
        truncated and included from the current file).

        Each processed file is written to ``<name>.new`` and renamed over the
        original.  An exclusive lock on ``snort_lock_file`` is held for the
        whole run and is released even if processing fails.
        """
        cfgfile = self.config_file[cfgname]

        regcomm = re.compile('^((#.*\n)|([ ]*)\n)$')
        regincl = re.compile('^(include|dynamic(preprocessor|engine|detection)|output|alert) (.+)\n$')
        regvars = re.compile('^(var|ipvar|portvar)[ \t]+([A-Za-z0-9_]+)[ \t]+(.+)\n$')
        regconf = re.compile('^(config|preprocessor)[ \t]+([A-Za-z0-9_]+)(:[ \t]+([^\\\\]*))?([\\\\][ ]*)?\n$')
        regnext = re.compile('^([^\\\\]+)([\\\\][ ]*)?\n$')

        # Patterns used by the non-'value' keys of a 'replace' action.
        regrepl = {
            'value': re.compile('(\\[.+?\\]|any)'),
            'ports': re.compile('(server_)?ports [{] ([0-9 ]+) [}]'),
        }

        # Plain 'var' definitions seen so far; used to expand include paths.
        varstable = {}
        # Config id -> file object opened by a 'move' action; closed at the end.
        filecache = {}

        def handle_data(inpname, outfile, line, stripline, data, value=None):
            """Apply action *data* to *line* and write the result.

            *value* is the ``(start, end)`` span of a variable's value inside
            *line*; only the ``'value'`` replacement uses it.
            """
            action = data[0]
            if action == 'disable':
                self._logger.info('commenting line: ' + stripline)
                outfile.write('# ' + line)

            elif action == 'replace':
                self._logger.debug('checking replace at "' + stripline + '"...')
                try:
                    replaced = False
                    for rkey in data[1]:
                        if rkey == 'value':
                            if value:
                                self._logger.info('replacing value "' + line[value[0]:value[1]] + '" with "' + data[1][rkey] + '" at: ' + stripline)
                                newline = line[:value[0]] + data[1][rkey] + line[value[1]:]
                                outfile.write(newline)
                                replaced = True
                                break
                            else:
                                self._logger.warning('got "value" replace but empty value, skipping...')
                        else:
                            self._logger.debug('matching "' + rkey + '" expression...')
                            mrepl = regrepl[rkey].search(line)
                            if mrepl is not None:
                                self._logger.info('replacing match "' + line[mrepl.start():mrepl.end()] + '" with "' + data[1][rkey] + '" at: ' + stripline)
                                newline = line[:mrepl.start()] + data[1][rkey] + line[mrepl.end():]
                                outfile.write(newline)
                                replaced = True
                                break
                    if not replaced:
                        outfile.write(line)
                except Exception as e:
                    # Best effort: keep the line unchanged, but record why.
                    self._logger.warning('replace failed at "' + stripline + '" (' + str(e) + '), keeping line as is')
                    outfile.write(line)

            elif action == 'move':
                fname = self.config_file.get(data[1])
                if fname is not None and fname == inpname:
                    self._logger.debug('line "' + stripline + '" already moved to "' + data[1] + '"...')
                    outfile.write(line)
                else:
                    self._logger.info('moving line "' + stripline + '" to "' + data[1] + '" config...')
                    outmove = filecache.get(data[1])
                    if outmove is None:
                        if fname is None:
                            self._logger.error('no config file id "' + data[1] + '" found, skipping..')
                            outmove = outfile
                        else:
                            self._logger.debug('opening file "' + fname + '" for moving data...')

                            outmove = open(fname, 'w')
                            outmove.write('#### Auto-generated from ' + sys.argv[0] + ' ####\n\n')

                            # Make the current file pull the moved lines back in.
                            outfile.write('\n#### Modified by ' + sys.argv[0] + ' ####\n')
                            outfile.write('include ' + fname + '\n\n')

                            filecache[data[1]] = outmove

                    outmove.write(line)
            else:
                self._logger.error('unknown action "' + action + '" at "' + stripline + '", skipping...')
                # Keep the line unchanged in the file being rewritten.
                outfile.write(line)

        def process_file(inpname, inpfile, outfile):
            """Copy *inpfile* to *outfile* line by line, applying the action tables."""
            self._logger.info('processing file ' + inpname + '...')

            # Backslash-continuation state, kept per file so an included file
            # that ends mid-directive cannot affect its parent.
            contline = False
            contdata = None

            for line in inpfile:
                stripline = line.rstrip('\n')
                if contline:
                    self._logger.debug('parsing multi-line at "' + stripline + '"...')
                    mnext = regnext.match(line)
                    if mnext is not None:
                        if contdata is not None:
                            handle_data(inpname, outfile, line, stripline, contdata)
                        else:
                            outfile.write(line)

                        if mnext.group(2) is None:
                            self._logger.debug('exiting multi-line mode...')
                            contline = False
                            contdata = None

                        continue

                    # Not a continuation line after all: leave multi-line mode
                    # and parse it as a new line below, which writes it once.
                    self._logger.warning('multi-line parsing error at "' + stripline + '", parsing as a new line.')
                    contline = False
                    contdata = None

                if regcomm.match(line):
                    outfile.write(line)
                    continue

                mincl = regincl.match(line)
                if mincl is not None:
                    if mincl.group(1) == 'include':
                        self._logger.debug('requesting include of "' + mincl.group(3) + '"...')
                        incname = self._substitute_vars(mincl.group(3), varstable)
                        self._logger.debug('scanning included "' + incname + '"...')
                        replace_file(incname)

                    outfile.write(line)
                    continue

                mvars = regvars.match(line)
                if mvars is not None:
                    self._logger.debug('got vars line: ' + stripline)
                    if mvars.group(1) == 'var':
                        varstable[mvars.group(2)] = self._substitute_vars(mvars.group(3), varstable)
                        outfile.write(line)
                    else:
                        data = replace_vars.get(mvars.group(1), {}).get(mvars.group(2))
                        if data is not None:
                            handle_data(inpname, outfile, line, stripline, data, (mvars.start(3), mvars.end(3)))
                        else:
                            outfile.write(line)
                    continue

                mconf = regconf.match(line)
                if mconf is not None:
                    self._logger.debug('got conf line: ' + stripline)
                    data = replace_conf.get(mconf.group(1), {}).get(mconf.group(2))

                    if data is not None:
                        handle_data(inpname, outfile, line, stripline, data)
                    else:
                        outfile.write(line)

                    if mconf.group(5) is not None:
                        # Trailing backslash: the following lines belong to this
                        # directive and get the same action (if any).
                        contline = True
                        contdata = data

                    continue

                self._logger.warning('parsing error at "' + stripline + '", not processing.')
                outfile.write(line)

        def replace_file(name):
            """Rewrite *name* via ``name + '.new'`` and rename it into place.

            On error the original file is left untouched.
            """
            newname = name + '.new'

            with open(name, 'r') as inpfile:
                with open(newname, 'w') as outfile:
                    process_file(name, inpfile, outfile)

            os.rename(newname, name)

        snortfd = open(snort_lock_file, 'a')
        try:
            fcntl.lockf(snortfd, fcntl.LOCK_EX)
            try:
                replace_file(cfgfile)
            finally:
                # Flush and close moved-to files while the lock is still held.
                for outmove in filecache.values():
                    outmove.close()
                fcntl.lockf(snortfd, fcntl.LOCK_UN)
        finally:
            snortfd.close()

    def do_init_update(self):
        """Split the stock main config into per-service files and set defaults.

        Unused variables and preprocessors are commented out.  HOME_NET and
        the HTTP/SSH/SIP settings are moved into their own files, which are
        then filled with the default values of the do_*_update methods.
        """
        replace_vars = {
            'ipvar': {
                'TELNET_SERVERS': ('disable', None),
                'FTP_SERVERS': ('disable', None),
                'AIM_SERVERS': ('disable', None),

                'HOME_NET': ('move', 'network'),
            },
            'portvar': {
                'ORACLE_PORTS': ('disable', None),
                'FTP_PORTS': ('disable', None),
                'FILE_DATA_PORTS': ('disable', None),

                'HTTP_PORTS': ('move', 'http'),
                'SSH_PORTS': ('move', 'ssh'),
                'SIP_PORTS': ('move', 'sip'),
            },
        }

        replace_conf = {
            'preprocessor': {
                'ftp_telnet': ('disable', None),
                'ftp_telnet_protocol': ('disable', None),
                'rpc_decode': ('disable', None),
                'dcerpc2': ('disable', None),
                'dcerpc2_server': ('disable', None),
                'imap': ('disable', None),
                'pop': ('disable', None),
                'smtp': ('disable', None),
                'http_inspect': ('move', 'http'),
                'http_inspect_server': ('move', 'http'),
                'ssl': ('move', 'http'),
                'ssh': ('move', 'ssh'),
                'sip': ('move', 'sip'),
            },
        }

        # process main file
        self.do_process_config('main', replace_vars, replace_conf)

        # now set the proper values at included files
        self.do_http_port_update([])
        self.do_sshd_port_update([])
        self.do_sip_port_update([])
        self.do_home_net_update([])

    def do_sip_port_update(self, origportlist):
        """Set SIP_PORTS and the sip preprocessor ports (default 5060, 5061)."""
        ports = self._unique_items(origportlist) or ['5060', '5061']

        sipvarstr = '[' + ','.join(ports) + ']'
        sipprestr = '{ ' + ' '.join(ports) + ' }'

        replace_vars = {'portvar': {'SIP_PORTS': ('replace', {'value': sipvarstr})}}
        replace_conf = {'preprocessor': {'sip': ('replace', {'ports': 'ports ' + sipprestr})}}

        self.do_process_config('sip', replace_vars, replace_conf)

    def do_http_port_update(self, origportlist):
        """Set the HTTP and SSL port lists (default 80, 81, 443).

        443 is treated as SSL only and 80/81 as HTTP only.  Any other port is
        used both for HTTP_PORTS/http_inspect_server and for the ssl
        preprocessor.
        """
        ports = self._unique_items(origportlist) or ['80', '81', '443']

        httplist = [p for p in ports if p != '443']
        ssllist = [p for p in ports if p not in ('80', '81')]

        httpvarstr = '[' + ','.join(httplist) + ']'

        httpprestr = '{ ' + ' '.join(httplist) + ' }'
        sslprestr = '{ ' + ' '.join(ssllist) + ' }'

        replace_vars = {'portvar': {'HTTP_PORTS': ('replace', {'value': httpvarstr})}}

        replace_conf = {
            'preprocessor': {
                'http_inspect_server': ('replace', {'ports': 'ports ' + httpprestr}),
                'ssl': ('replace', {'ports': 'ports ' + sslprestr}),
            },
        }

        self.do_process_config('http', replace_vars, replace_conf)

    def do_sshd_port_update(self, origportlist):
        """Set SSH_PORTS and the ssh preprocessor server_ports (default 22)."""
        ports = self._unique_items(origportlist) or ['22']

        sshvarstr = '[' + ','.join(ports) + ']'
        sshprestr = '{ ' + ' '.join(ports) + ' }'

        replace_vars = {'portvar': {'SSH_PORTS': ('replace', {'value': sshvarstr})}}
        replace_conf = {'preprocessor': {'ssh': ('replace', {'ports': 'server_ports ' + sshprestr})}}

        self.do_process_config('ssh', replace_vars, replace_conf)

    def do_home_net_update(self, orignetworklist):
        """Set HOME_NET to the given networks (default 127.0.0.1/32)."""
        networks = self._unique_items(orignetworklist) or ['127.0.0.1/32']

        netstr = '[' + ','.join(networks) + ']'

        replace_vars = {'ipvar': {'HOME_NET': ('replace', {'value': netstr})}}

        self.do_process_config('network', replace_vars, {})

    def reload_snort(self):
        """Send SIGHUP to every pid listed in a /var/run/snort_* file.

        :returns: False if /var/run cannot be listed, True otherwise.
            Failures for individual pid files are logged and skipped.
        """
        pidpath = '/var/run/'
        try:
            pidlist = os.listdir(pidpath)
        except OSError as e:
            self._logger.error('unable to list pidfiles: %s' % str(e))
            return False

        for fn in pidlist:
            if not fn.startswith('snort_'):
                continue

            pidfile = os.path.join(pidpath, fn)
            data = ''
            try:
                with open(pidfile, 'r') as fh:
                    data = fh.read().strip('\n ')
                os.kill(int(data), signal.SIGHUP)
            except (IOError, OSError, ValueError) as e:
                self._logger.warning('unable to send SIGHUP to %s: %s' % (data or pidfile, str(e)))

        return True

## main() ##

# Exactly one action runs per invocation, checked in this order:
# --init, --ports (SIP), --http-ports, --ssh-ports, --home-network.
# After a successful update snort is asked to reload via SIGHUP.

snortconf = None

try:
    snortconf = SnortConfUpdater()

    if snortconf.options.debug is not None:
        snortconf._logger.setLevel(logging.DEBUG)

    if snortconf.options.init is not None:
        snortconf.do_init_update()
    elif snortconf.options.sip_port_list is not None:
        portarray = snortconf.options.sip_port_list.split(',')
        snortconf.do_sip_port_update(portarray)
    elif snortconf.options.http_port_list is not None:
        portarray = snortconf.options.http_port_list.split(',')
        snortconf.do_http_port_update(portarray)
    elif snortconf.options.ssh_port_list is not None:
        portarray = snortconf.options.ssh_port_list.split(',')
        snortconf.do_sshd_port_update(portarray)
    elif snortconf.options.home_network is not None:
        ipsarray = snortconf.options.home_network.split(',')
        snortconf.do_home_net_update(ipsarray)
    else:
        snortconf.parser.print_help()
        sys.exit(1)

    snortconf.reload_snort()

# SystemExit (from sys.exit() above or from optparse on bad arguments) and
# KeyboardInterrupt derive from BaseException, so they propagate unchanged.
except Exception:
    exc_type, exc_value = sys.exc_info()[:2]
    if snortconf is not None:
        snortconf._logger.error("Unexpected exception: %s/%s\n" % (exc_type, exc_value))
        snortconf._logger.error('-' * 60 + "\n")
        snortconf._logger.error(traceback.format_exc())
        snortconf._logger.error('-' * 60 + "\n")
    else:
        sys.stderr.write("Unexpected exception %s/%s, aborting.\n" % (exc_type, exc_value))
    sys.exit(1)

