#!/usr/bin/env python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2012  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author(s)
* Johnny Ma <jma@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import os
import re
import sys
import shutil
import logging
import fcntl
import traceback
from optparse import OptionParser
import logging
import logging.handlers
import subprocess
import signal

# Lock file shared by every invocation of this script: the script file
# itself (sys.argv[0]); the updater methods open it in append mode and
# serialize on an exclusive fcntl.lockf() lock.
snort_lock_file = sys.argv[0]

# Include files generated by this script; --init makes snort.conf
# `include` them in place of the original variable definitions.
_snort_include_dir  = '/etc/snort'
config_ssh_ports    = os.path.join(_snort_include_dir, 'ports-ssh.conf')
config_http_ports   = os.path.join(_snort_include_dir, 'ports-http.conf')
config_home_network = os.path.join(_snort_include_dir, 'home-network.conf')

class SnortConfUpdater:
    """Edit the Snort configuration from the command line.

    The constructor parses ``sys.argv``.  Each ``do_*`` method rewrites
    one part of the configuration and returns True on success or False
    after logging a handled failure.  ``reload_snort`` sends SIGHUP to
    the pids recorded in ``snort_*`` files under the pid directory.
    """

    def __init__(self):

        self._script_name = 'snortconfupdater'

        self._logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        self._logger = logging.getLogger(self._script_name)
        self._logger.setLevel(logging.DEBUG)

        # getLogger() returns the same logger object for the same name, so
        # attach the console handler only once; a second instance would
        # otherwise print every message twice.
        if not self._logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter(self._logformat))
            self._logger.addHandler(console_handler)

        # custom SnortConfUpdater options should be added here via parser.add_option()
        self.parser = OptionParser()
        self.parser.add_option("", "--init", dest="init_conf", action='store_true',
                        help="Initialize Snort Configuration")

        self.parser.add_option("", "--home-network", dest="home_network",
                        help="Update Snort Home Network IPs")

        self.parser.add_option("", "--ssh-ports", dest="ssh_port_list",
                        help="Update Snort SSH Ports")

        self.parser.add_option("", "--http-ports", dest="http_port_list",
                        help="Update Snort HTTP Ports")

        self.parser.add_option("", "--ports", dest="sip_port_list",
                        help="Update Snort Configuration SIP Ports")

        (self.options, args) = self.parser.parse_args()

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unique(values):
        """Return the distinct, non-empty entries of *values*.

        Entries are whitespace-stripped.  Empty ones, such as those a
        trailing comma in ``--ports 5060,`` produces, are dropped.
        First-seen order is kept so the generated file is deterministic.
        """
        seen = set()
        result = []
        for value in values:
            value = value.strip()
            if value and value not in seen:
                seen.add(value)
                result.append(value)
        return result

    @staticmethod
    def _acquire_lock():
        """Open ``snort_lock_file`` and take an exclusive lockf() lock on it."""
        lockfd = open(snort_lock_file, 'a')
        try:
            fcntl.lockf(lockfd, fcntl.LOCK_EX)
        except BaseException:
            lockfd.close()
            raise
        return lockfd

    @staticmethod
    def _release_lock(lockfd):
        """Release a lock taken by _acquire_lock() and close its file."""
        try:
            fcntl.lockf(lockfd, fcntl.LOCK_UN)
        finally:
            lockfd.close()

    def _write_var_file(self, configname, declaration, values, default):
        """Replace *configname* with a file defining one Snort variable.

        The file contains an auto-generated banner followed by
        ``<declaration> [<values joined by ','>]``.  *default* is used
        when *values* is empty.  The content is written to
        ``<configname>.new`` first and then os.rename()d over
        *configname*.

        Returns True on success, or False (after logging) on I/O failure.
        """
        configtemp = configname + '.new'
        valuestr = ','.join(values) if values else default

        try:
            with open(configtemp, 'w') as fh:
                fh.write('#### Auto-generated file from ' + sys.argv[0] + ' ####\n\n')
                fh.write(declaration + ' [' + valuestr + ']\n')
        except (IOError, OSError) as e:
            self._logger.error('unable to open %s: %s' % (configtemp, str(e)))
            return False

        try:
            os.rename(configtemp, configname)
        except OSError as e:
            self._logger.error('unable to replace configuration file %s: %s' % (configname, str(e)))
            return False

        return True

    # ------------------------------------------------------------------
    # updates
    # ------------------------------------------------------------------

    def do_init_update(self, cfgfile='/etc/snort.conf'):
        """Make *cfgfile* include the generated HOME_NET/HTTP/SSH files.

        The three include files are reset to their defaults first.  Then
        every line of *cfgfile* starting with ``ipvar HOME_NET``,
        ``portvar HTTP_PORTS`` or ``portvar SSH_PORTS`` is commented out
        and preceded by an ``include`` of the matching generated file.
        Lines that are already commented out no longer match.

        Returns True on success.  Returns False if an include file could
        not be written; *cfgfile* is then left untouched.  Errors reading
        or writing *cfgfile* propagate to the caller.
        """
        replacements = (
            (re.compile('^ipvar HOME_NET .+'), config_home_network),
            (re.compile('^portvar HTTP_PORTS .+'), config_http_ports),
            (re.compile('^portvar SSH_PORTS .+'), config_ssh_ports),
        )
        tmpfile = cfgfile + '.new'

        lockfd = self._acquire_lock()
        try:
            # init external files first; never point snort.conf at an
            # include file that failed to be written
            if not (self.do_http_port_update([]) and
                    self.do_sshd_port_update([]) and
                    self.do_home_net_update([])):
                self._logger.error('unable to initialize include files, %s not modified' % cfgfile)
                return False

            with open(cfgfile, 'r') as cfgreader:
                lines = cfgreader.readlines()

            with open(tmpfile, 'w') as outwriter:
                for line in lines:
                    for regex, include in replacements:
                        if regex.match(line):
                            outwriter.write('include ' + include + '\n')
                            outwriter.write('#' + line)
                            break
                    else:
                        outwriter.write(line)

            os.rename(tmpfile, cfgfile)
        finally:
            self._release_lock(lockfd)

        return True

    def do_sip_port_update(self, origportlist, cfgfile='/etc/snort.conf'):
        """Rewrite the SIP port list in *cfgfile*.

        Two edits are made:
          1. a ``portvar SIP_PORTS [...]`` line is replaced with the new list;
          2. the line right after a line starting with ``preprocessor sip``
             is replaced with ``   ports { <ports> }, \\``.  This assumes
             the ports option is the first continuation line of that
             preprocessor; NOTE(review): confirm against the shipped
             snort.conf.

        The new content is copied back over *cfgfile* with
        shutil.copyfile rather than renamed.  Presumably this keeps the
        existing file's inode and permissions; confirm that is the intent.

        Returns True on success.  Returns False if no usable port was
        given; *cfgfile* is then left untouched.
        """
        portlist = self._unique(origportlist)
        if not portlist:
            # "[]" / "{  }" would leave snort.conf with an unusable SIP
            # port list; refuse instead of writing it
            self._logger.error('no SIP ports given, %s not modified' % cfgfile)
            return False

        portstr = ','.join(portlist)
        outfile = cfgfile + '.new'
        sipportvar = re.compile(r'^portvar SIP_PORTS \[[0-9, ]+\]\n')
        sippreproc = re.compile(r'preprocessor sip')

        lockfd = self._acquire_lock()
        try:
            with open(cfgfile, 'r') as cfgreader:
                lines = cfgreader.readlines()

            nextlinemodify = False
            with open(outfile, 'w') as outwriter:
                for line in lines:
                    if sipportvar.match(line):                  # Job 1
                        outwriter.write("portvar SIP_PORTS [" + portstr + "]\n")
                    elif sippreproc.match(line):                # Job 2
                        nextlinemodify = True
                        outwriter.write(line)
                    elif nextlinemodify:
                        outwriter.write("   ports { " + portstr + " }, \\\n")
                        nextlinemodify = False
                    else:
                        outwriter.write(line)

            shutil.copyfile(outfile, cfgfile)
            os.remove(outfile)
        finally:
            self._release_lock(lockfd)

        return True

    def do_http_port_update(self, origportlist, configname=None):
        """Write the HTTP_PORTS include file (80,81,443 when no port is given).

        *configname* defaults to the module-level ``config_http_ports``.
        Returns True on success, or False on I/O failure (already logged).
        """
        if configname is None:
            configname = config_http_ports
        return self._write_var_file(configname, 'portvar HTTP_PORTS',
                                    self._unique(origportlist), '80,81,443')

    def do_sshd_port_update(self, origportlist, configname=None):
        """Write the SSH_PORTS include file (22 when no port is given).

        *configname* defaults to the module-level ``config_ssh_ports``.
        Returns True on success, or False on I/O failure (already logged).
        """
        if configname is None:
            configname = config_ssh_ports
        return self._write_var_file(configname, 'portvar SSH_PORTS',
                                    self._unique(origportlist), '22')

    def do_home_net_update(self, orignetworklist, configname=None):
        """Write the HOME_NET include file (127.0.0.1/32 when none is given).

        *configname* defaults to the module-level ``config_home_network``.
        Returns True on success, or False on I/O failure (already logged).
        """
        if configname is None:
            configname = config_home_network
        return self._write_var_file(configname, 'ipvar HOME_NET',
                                    self._unique(orignetworklist), '127.0.0.1/32')

    def reload_snort(self, pidpath='/var/run/'):
        """Send SIGHUP to every pid stored in a ``snort_*`` file in *pidpath*.

        A failure on one file (unreadable, non-numeric content, or a pid
        that cannot be signalled) is logged as a warning and skipped.
        Returns False only when *pidpath* itself cannot be listed.
        """
        try:
            pidlist = os.listdir(pidpath)
        except OSError as e:
            self._logger.error('unable to list pidfiles: %s' % str(e))
            return False

        for fn in pidlist:
            if not fn.startswith('snort_'):
                continue

            pidfile = os.path.join(pidpath, fn)
            try:
                with open(pidfile, 'r') as fh:
                    data = fh.read().strip('\n ')
            except (IOError, OSError) as e:
                self._logger.warning('unable to read pidfile %s: %s' % (pidfile, str(e)))
                continue

            try:
                pid = int(data)
                os.kill(pid, signal.SIGHUP)
            except (ValueError, OverflowError, OSError) as e:
                self._logger.warning('unable to send SIGHUP to %s: %s' % (data, str(e)))

        return True

## main() ##
#
# Apply exactly one update, chosen by the first option present in the
# order tested below, then SIGHUP the running Snort processes.
# Exit status: 0 on success; 1 on a usage error, when the update method
# reports failure (returns False), or on any unexpected exception.

snortconf = None

try:
    snortconf = SnortConfUpdater()
    options = snortconf.options

    if options.init_conf is not None:
        result = snortconf.do_init_update()
    elif options.sip_port_list is not None:
        result = snortconf.do_sip_port_update(options.sip_port_list.split(','))
    elif options.http_port_list is not None:
        result = snortconf.do_http_port_update(options.http_port_list.split(','))
    elif options.ssh_port_list is not None:
        result = snortconf.do_sshd_port_update(options.ssh_port_list.split(','))
    elif options.home_network is not None:
        result = snortconf.do_home_net_update(options.home_network.split(','))
    else:
        snortconf.parser.print_help()
        sys.exit(1)

    # The update methods log and return False on a handled failure; skip
    # the reload and report it through the exit status.  Compare with
    # False explicitly because some update methods may return None.
    if result is False:
        sys.exit(1)

    snortconf.reload_snort()

except SystemExit:
    # Let sys.exit() (ours, or optparse's on --help / bad options) pass
    # through untouched instead of being reported as unexpected below.
    raise

except Exception:
    exc_type, exc_value = sys.exc_info()[:2]
    if snortconf is not None:
        snortconf._logger.error("Unexpected exception: %s/%s\n" % (exc_type, exc_value))
        snortconf._logger.error('-'*60 + "\n")
        snortconf._logger.error(traceback.format_exc())
        snortconf._logger.error('-'*60 + "\n")
    else:
        sys.stderr.write("Unexpected exception %s/%s, aborting.\n" % (exc_type, exc_value))
    sys.exit(1)

