#!/usr/bin/env python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2012  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author(s)
* Johnny Ma <jma@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import sys
import os
import re
import shutil
import logging
import fcntl
import traceback
from optparse import OptionParser

# Path used as the inter-process lock target: this script's own file
# (sys.argv[0]).  Both update operations open it in append mode and hold an
# exclusive fcntl.lockf() lock on it while rewriting the Snort config, so
# concurrent runs are serialised.  NOTE(review): opening in append mode means
# the invoking user needs write permission on the script file itself.
snort_lock_file = sys.argv[0]

class SnortConfUpdater:
    """Command-line helper that rewrites parts of the Snort configuration.

    The action is selected by options parsed in ``__init__``:

    * ``--init``  -- replace the inline ``ipvar HOME_NET`` definition with an
      ``include`` of a separate home-network file (see do_init_update()).
    * ``--ports`` -- rewrite the SIP port list (see do_port_update()).

    Both actions hold an exclusive ``fcntl.lockf`` lock for their whole
    read-modify-write cycle and release it even when an error occurs.
    """

    # Line patterns, compiled once instead of on every loop iteration.
    _HOME_NET_RE = re.compile(r'^ipvar HOME_NET .+')
    # Tolerates CRLF endings and a final line without a trailing newline.
    _SIP_PORTVAR_RE = re.compile(r'^portvar SIP_PORTS \[[0-9, ]+\]\s*$')
    _SIP_PREPROC_RE = re.compile(r'preprocessor sip')

    def __init__(self):

        self._script_name = 'snortconfupdater'

        self._logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        self._logger = logging.getLogger(self._script_name)
        self._logger.setLevel(logging.DEBUG)

        # Loggers are process-global: attach the console handler only once so
        # a second instance does not duplicate every message.
        if not self._logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter(self._logformat))
            self._logger.addHandler(console_handler)

        # NOTE(review): not used anywhere in this file.
        self._null = open(os.devnull, 'w')

        # custom SnortConfUpdater options should be added here via parser.add_option()
        self.parser = OptionParser()
        self.parser.add_option("", "--init", dest="init_conf", action='store_true',
                        help="Initialize Snort Configuration")
        self.parser.add_option("", "--ports", dest="sip_port_list",
                        help="Update Snort Configuration SIP Ports")

        (self.options, _args) = self.parser.parse_args()
        self._sorted_port_list = []

    # ---------------------------------------------------------------- locking

    def _acquire_lock(self, lock_file=None):
        """Open *lock_file* and block until an exclusive lockf() lock is held.

        :param lock_file: path to lock; defaults to the module-level
            ``snort_lock_file``.
        :returns: the open lock file, to be passed to _release_lock().
        """
        if lock_file is None:
            lock_file = snort_lock_file
        lockfd = open(lock_file, 'a')
        try:
            fcntl.lockf(lockfd.fileno(), fcntl.LOCK_EX)
        except Exception:
            lockfd.close()
            raise
        return lockfd

    @staticmethod
    def _release_lock(lockfd):
        """Release the lockf() lock taken by _acquire_lock() and close the file."""
        try:
            fcntl.lockf(lockfd.fileno(), fcntl.LOCK_UN)
        finally:
            lockfd.close()

    # ------------------------------------------------------------- port list

    def _add_to_list(self, str_to_add):
        """Append *str_to_add* to the port list unless it is already present."""
        if str_to_add not in self._sorted_port_list:
            self._sorted_port_list.append(str_to_add)

    def _sort_and_unique(self, portlist):
        """Merge *portlist* into the stored port list, deduplicated and sorted.

        Entries are stripped of surrounding whitespace; empty entries (for
        example from "5060,,5061" or a trailing comma) are skipped.  Ports are
        normalised ("05060" -> "5060") and sorted numerically.  *portlist*
        itself is not modified.

        :raises ValueError: if an entry is not an integer in 1-65535.  Writing
            such a value would corrupt the config.
        """
        for element in portlist:
            port = element.strip()
            if not port:
                continue
            if not port.isdigit() or not 0 < int(port) <= 65535:
                raise ValueError("invalid SIP port: %r" % (element,))
            self._add_to_list(str(int(port)))
        self._sorted_port_list.sort(key=int)

    # --------------------------------------------------------------- actions

    def _ensure_home_network_conf(self, cfghome):
        """Create *cfghome* with a loopback-only HOME_NET if it does not exist.

        This is best effort: a failure is logged as a warning and not raised.
        The caller writes the ``include`` line into the main config either way.
        """
        if os.path.exists(cfghome):
            return
        try:
            with open(cfghome, 'w') as fh:
                fh.write('#### Auto-generated file from ' + sys.argv[0] + ' ####\n\n')
                fh.write('ipvar HOME_NET [127.0.0.1/32]\n')
        except (IOError, OSError) as err:
            self._logger.warning("Could not create %s: %s", cfghome, err)

    def do_init_update(self, cfgfile='/etc/snort.conf',
                       cfghome='/etc/snort/home-network.conf', lock_file=None):
        """Move the HOME_NET definition out of *cfgfile* into *cfghome*.

        Every line starting with ``ipvar HOME_NET `` is commented out and
        preceded by ``include <cfghome>``.  *cfghome* is created with a
        default definition if missing.  Because the original line ends up
        commented out, running this again does not add a second include.

        :param cfgfile: Snort config to rewrite.
        :param cfghome: home-network file to include (and create if missing).
        :param lock_file: file to lock while rewriting; defaults to the
            module-level ``snort_lock_file``.
        """
        tmpfile = cfgfile + '.new'

        lockfd = self._acquire_lock(lock_file)
        try:
            with open(cfgfile, 'r') as reader:
                lines = reader.readlines()

            self._ensure_home_network_conf(cfghome)

            with open(tmpfile, 'w') as writer:
                for line in lines:
                    # Job 1
                    if self._HOME_NET_RE.match(line):
                        writer.write('include ' + cfghome + '\n')
                        writer.write('#' + line)
                    else:
                        writer.write(line)

            os.rename(tmpfile, cfgfile)
        finally:
            self._release_lock(lockfd)

    def do_port_update(self, portlist, cfgfile='/etc/snort.conf', lock_file=None):
        """Rewrite the SIP port list in *cfgfile*.

        The ``portvar SIP_PORTS [...]`` line is replaced with the ports joined
        by commas.  The line immediately following a line starting with
        ``preprocessor sip`` is replaced with a ``ports { ... }`` continuation
        line.  NOTE(review): that following line is assumed to be the
        preprocessor's ports option; it is overwritten unconditionally.

        :param portlist: sequence of port strings, e.g. ``"5060,5061".split(',')``.
        :param cfgfile: Snort config to rewrite in place.
        :param lock_file: file to lock while rewriting; defaults to the
            module-level ``snort_lock_file``.
        :raises ValueError: if an entry is not a valid port or no ports remain.
            This is checked before the lock is taken or any file is touched.
        """
        self._sort_and_unique(portlist)
        if not self._sorted_port_list:
            # "[]" would break Snort and stop the portvar regex matching later.
            raise ValueError("no SIP ports given")

        portvar_line = "portvar SIP_PORTS [" + ",".join(self._sorted_port_list) + "]\n"
        ports_line = "   ports { " + " ".join(self._sorted_port_list) + " }, \\\n"
        tmpfile = cfgfile + '.new'

        lockfd = self._acquire_lock(lock_file)
        try:
            with open(cfgfile, 'r') as reader:
                lines = reader.readlines()

            nextlinemodify = False
            with open(tmpfile, 'w') as writer:
                for line in lines:
                    if self._SIP_PORTVAR_RE.match(line):      # Job 1
                        writer.write(portvar_line)
                    elif self._SIP_PREPROC_RE.match(line):    # Job 2
                        nextlinemodify = True
                        writer.write(line)
                    elif nextlinemodify:
                        writer.write(ports_line)
                        nextlinemodify = False
                    else:
                        writer.write(line)

            # copyfile() writes into the existing file, so its ownership and
            # mode stay unchanged.  Unlike the os.rename() in do_init_update(),
            # the replacement is not atomic.
            shutil.copyfile(tmpfile, cfgfile)
            os.remove(tmpfile)
        finally:
            self._release_lock(lockfd)

## main() ##

snortconf = None

try:
    snortconf = SnortConfUpdater()

    if snortconf.options.init_conf is not None:
        snortconf.do_init_update()
        sys.exit(0)

    if snortconf.options.sip_port_list is not None:
        portarray = snortconf.options.sip_port_list.split(',')
        snortconf.do_port_update(portarray)
        sys.exit(0)

    # Neither action was requested: show usage and report failure.
    snortconf.parser.print_help()
    sys.exit(1)

except SystemExit:
    # Let deliberate exits (the sys.exit() calls above, and optparse's own
    # exits for --help or bad options) propagate with their status unchanged,
    # so that the catch-all below only sees real errors.
    raise

except Exception:
    exc_type, exc_value = sys.exc_info()[:2]
    if snortconf is not None:
        snortconf._logger.error("Unexpected exception: %s/%s\n" % (exc_type, exc_value))
        snortconf._logger.error('-'*60 + "\n")
        snortconf._logger.error(traceback.format_exc())
        snortconf._logger.error('-'*60 + "\n")
    else:
        sys.stderr.write("Unexpected exception %s/%s, aborting.\n" % (exc_type, exc_value))
    sys.exit(1)

