#!/usr/local/nsc/bin/python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2012  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author(s)
* Moises Silva <moy@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import sys
import os
import time
import logging
import oswc
import sngpy 
import pyodbc
import ipaddr
import subprocess
import re
import signal
import errno
import ctypes
from optparse import OptionParser
from datetime import datetime
from datetime import timedelta

# iptc (python-iptables) lets us modify iptables much faster, without
# spawning a whole new iptables process (via subprocess) for each
# modification. It is optional: when it cannot be imported, the code below
# falls back to running the iptables binary.
try:
    import iptc
except:
    iptc_enabled = False
else:
    iptc_enabled = True

class SIPSecRuleMatch(object):
    """
    Tracks the failed attempts that matched a rule under one match key.

    Instances are created by SIPSecRule.match(): ``events`` accumulates the
    matching events still inside the rule's time frame and ``sched_event``
    holds the pending expiration timer (if any).
    """

    def __init__(self, name, rule, logger):
        # identity of this match (the key it is stored under) and its owner
        self._name, self._rule, self._logger = name, rule, logger

        # per-match state, filled in by the caller
        self.events = []
        self.src_ip = self.profile = self.account = ''
        self.sched_event = None

        logger.debug("New rule match '%s'" % (name))

    def __str__(self):
        return self._name

class SIPSecAction(object):
    """
    Base class for the actions a rule triggers when it fires.

    Subclasses override _execute_impl(); execute() is the public entry point.
    """

    def __init__(self, args, rule, logger):
        self._args = args  # raw argument text taken from the rule's action_expr
        self._rule = rule
        self._logger = logger

    def execute(self, rule_match):
        """Run this action for the given SIPSecRuleMatch."""
        self._execute_impl(rule_match)

    def _execute_impl(self, rule_match):
        """Hook for subclasses; the base implementation does nothing."""
        return None

# This is useful to block script-kiddies using scanners to find vulnerable accounts
class BlockIPAction(SIPSecAction):
    """
    Action that blocks the source IP of a rule match.

    Each block is recorded in the blocked objects table and enforced with a
    DROP rule in the secmon chain, via iptc when it is available and by
    running the iptables binary otherwise.
    """

    # block_expiration stored for blocks that never expire
    PERMANENT_EXPIRATION = "9999-12-31 23:59:59"
    # blocks of this many seconds or more are stored as permanent
    MAX_TIMED_BLOCK_SECONDS = 86400
    # format used for the block_time / block_expiration DATETIME columns
    TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self, args, rule, logger):
        super(BlockIPAction, self).__init__(args, rule, logger)
        self._block_time = int(args) # This is in minutes
        self._logger.debug("Created action to block IP "
                           "for %d minutes" % self._block_time)

    def _execute_impl(self, rule_match):
        BlockIPAction.perform_block(rule_match._rule._id,
                                    rule_match.src_ip,
                                    (self._block_time * 60))

    @staticmethod
    def perform_block(rule_id, target_ip, block_seconds):
        """
        Record a block of target_ip in the database and enforce it.

        A block_seconds value from 1 to 86399 produces a timed block whose
        removal is scheduled. Any other value (0, negative, or a day or
        more) produces a permanent block stored with PERMANENT_EXPIRATION.
        """
        secmon = get_secmon()
        query = ("INSERT INTO %s (rule_id, name, type, block_time, block_expiration) "
                 "VALUES (?, ?, ?, ?, ?)" % (secmon.blocked_objects_table))
        now = datetime.now()

        if 0 < block_seconds < BlockIPAction.MAX_TIMED_BLOCK_SECONDS:
            xtime = now + timedelta(seconds=block_seconds)
            xtimestr = xtime.strftime(BlockIPAction.TIME_FORMAT)
            fw_seconds = block_seconds
        else:
            # permanent block: fw_block_ip schedules no unblock for 0 seconds
            xtime = now
            xtimestr = BlockIPAction.PERMANENT_EXPIRATION
            fw_seconds = 0

        secmon.sql_exec(query,
                        rule_id,
                        target_ip,
                        'ip',
                        now.strftime(BlockIPAction.TIME_FORMAT),
                        xtimestr)

        BlockIPAction.fw_block_ip(target_ip, fw_seconds, xtime)

    @staticmethod
    def unblock_ip_ex(src_addr):
        """
        Force unblocking an IP from an external process.

        Marks the blocks of src_addr as already expired in the database and
        sends SIGHUP to the running secmonitor process (if found) so it acts
        on it right away.

        Returns False when no database row was updated, True otherwise.
        """
        secmon = get_secmon()
        # Force the expiration to be immediate
        query = ("UPDATE %s SET block_expiration = '0000-00-00 00:00:00' WHERE name = ?"
                 % (secmon.blocked_objects_table))
        cursor = secmon.sql_exec(query, src_addr)
        if cursor.rowcount <= 0:
            return False
        # find the secmonitor process
        proc = sngpy.Service.find_process()
        if proc is not None:
            # send a SIGHUP signal to wake up immediately
            proc.send_signal(signal.SIGHUP)
        return True

    @staticmethod
    def _iptc_update_drop_rule(secmon, src_addr, append, where):
        """
        Append (append=True) or delete (append=False) the DROP rule for
        src_addr in the secmon chain using iptc.

        EAGAIN (kernel consistency check failure) is retried. Returns None
        on success, or the iptc exception for any other failure; the caller
        decides how to report it. `where` names the caller in debug logs.
        """
        while True:
            try:
                iptc.TABLE_FILTER.restart()
                sec_chain = iptc.Chain(iptc.TABLE_FILTER, secmon.chain)

                rule = iptc.Rule()
                rule.src = src_addr
                rule.target = iptc.Target(rule, 'DROP')

                if append:
                    sec_chain.append_rule(rule)
                else:
                    sec_chain.delete_rule(rule)
                iptc.TABLE_FILTER.close()
                return None
            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    return e
                secmon._logger.debug("Kernel consistency check failed at %s, retrying.." % (where))
                time.sleep(0.05)

    @staticmethod
    def fw_block_ip(src_addr, seconds, xtime):
        """
        Add the firewall DROP rule for src_addr.

        When seconds > 0, unblock_ip(src_addr, xtime) is scheduled to run
        after that many seconds; seconds == 0 means a permanent block.
        """
        secmon = get_secmon()
        minutes = seconds // 60
        secmon._logger.info("Blocking IP %s for about %d minutes (%d seconds exactly)" % (src_addr, minutes, seconds))
        if iptc_enabled:
            err = BlockIPAction._iptc_update_drop_rule(secmon, src_addr, True, 'fw_block_ip')
            if err is not None:
                secmon._logger.error("Error blocking IP %s: %s" % (src_addr, str(err)))
        else:
            cmd = [secmon.iptables_bin, '--table', 'filter', '--append', secmon.chain, '-s', src_addr, '-j', 'DROP']
            secmon.cmd_exec(cmd)

        if seconds > 0:
            secmon.sched.enter(seconds, 1, BlockIPAction.unblock_ip, (src_addr, xtime))

    @staticmethod
    def unblock_ip(src_addr, xtime):
        """
        Remove the firewall rule for src_addr and delete its database record.

        When xtime is given, only the record whose block_expiration equals
        xtime (to the second) is deleted. Records for the same IP with a
        different expiration are kept.
        """
        secmon = get_secmon()
        BlockIPAction.fw_unblock_ip(src_addr)
        # src_addr comes from SIP event headers: bind it rather than
        # formatting it into the SQL text
        query = "DELETE FROM %s WHERE type = 'ip' AND name = ?" % (secmon.blocked_objects_table)
        params = [src_addr]
        if xtime is not None:
            query += " AND block_expiration = ?"
            params.append(xtime.strftime(BlockIPAction.TIME_FORMAT))
        secmon.sql_exec(query, *params)

    @staticmethod
    def fw_unblock_ip(src_addr):
        """Remove the firewall DROP rule for src_addr from the secmon chain."""
        secmon = get_secmon()
        secmon._logger.info("Unblocking IP %s" % (src_addr))
        if iptc_enabled:
            err = BlockIPAction._iptc_update_drop_rule(secmon, src_addr, False, 'fw_unblock_ip')
            if err is not None:
                secmon._logger.error("Error unblocking IP %s: %s (%d)" % (src_addr, str(err), len(secmon.sched.queue)))
        else:
            cmd = [secmon.iptables_bin, '--table', 'filter', '--delete', secmon.chain, '-s', src_addr, '-j', 'DROP']
            secmon.cmd_exec(cmd)

    @staticmethod
    def unblock_expired():
        """
        Lift every IP block whose expiration time has passed.

        A database record is deleted only after its firewall rule was
        removed successfully. Failed removals therefore stay recorded.
        """
        secmon = get_secmon()
        now = datetime.now().strftime(BlockIPAction.TIME_FORMAT)
        query = "SELECT * FROM %s WHERE block_expiration <= '%s' AND type = 'ip'" % (secmon.blocked_objects_table, now)
        # Materialize the rows first: the loop below issues DELETEs through
        # sql_exec, which must not disturb the result set being iterated
        # (e.g. if sql_exec reuses one cursor).
        expired = list(secmon.sql_exec(query))
        delete_query = "DELETE FROM %s WHERE blocked_object_id = ?" % (secmon.blocked_objects_table)
        for obj in expired:
            if iptc_enabled:
                err = BlockIPAction._iptc_update_drop_rule(secmon, obj.name, False, 'unblock_expired')
                if err is not None:
                    secmon._logger.error("Error unblocking IP %s: %s (%d)" % (obj.name, str(err), len(secmon.sched.queue)))
                removed = err is None
            else:
                cmd = [secmon.iptables_bin, '--table', 'filter', '--delete', secmon.chain, '-s', obj.name, '-j', 'DROP']
                removed = (secmon.cmd_exec(cmd) == 0)
            if removed:
                secmon.sql_exec(delete_query, obj.blocked_object_id)

    @staticmethod
    def _remaining_block_seconds(xtime, now):
        """
        Return how long a block expiring at xtime still lasts, in the form
        fw_block_ip expects.

        Returns 0 (permanent) for PERMANENT_EXPIRATION. Otherwise returns
        the full remaining time including whole days (timedelta.seconds
        alone drops the days), and at least 1 second, because 0 would make
        fw_block_ip treat an almost-expired block as permanent.
        """
        if xtime.strftime(BlockIPAction.TIME_FORMAT) == BlockIPAction.PERMANENT_EXPIRATION:
            return 0
        period = xtime - now
        remaining = period.days * 86400 + period.seconds
        return max(1, remaining)

    @staticmethod
    def block_all():
        """
        Re-apply the firewall rule for every IP block that has not expired
        yet, rescheduling its unblock for the remaining time.
        """
        secmon = get_secmon()
        now = datetime.now()
        now_str = now.strftime(BlockIPAction.TIME_FORMAT)
        # Read from database which IPs are supposed to be blocked
        dbquery = "SELECT * FROM %s WHERE block_expiration > '%s' AND type = 'ip'" % (secmon.blocked_objects_table, now_str)
        cursor = secmon.sql_exec(dbquery)

        # Block each object as needed
        for obj in cursor:
            secmon._logger.debug("Re-blocking IP %s, block_expiration=%s", obj.name, obj.block_expiration)
            xtime = datetime(*(time.strptime(str(obj.block_expiration), BlockIPAction.TIME_FORMAT)[0:6]))
            seconds = BlockIPAction._remaining_block_seconds(xtime, now)
            BlockIPAction.fw_block_ip(obj.name, seconds, xtime)

class Filter(object):
    """
    Base class for the per-rule match filters (IP, account, user agent).

    Subclasses implement _match_impl(obj, *extra); callers go through
    match().
    """

    def __init__(self, expr, rule, logger):
        self._expr = expr
        self._rule = rule
        self._logger = logger

    def match(self, obj, *args):
        """
        Return whether obj passes this filter.

        Extra positional arguments are forwarded individually to
        _match_impl (e.g. AccountFilter receives the invalid-user flag).
        """
        # Unpack rather than passing the args tuple itself: a non-empty
        # tuple is always truthy, which would hide the real flag value.
        return self._match_impl(obj, *args)

    def _match_impl(self, obj, *args):
        """Default implementation: nothing matches."""
        return False

class IPFilter(Filter):
    """
    Filter matching source IP addresses against a comma separated list.

    Each element is an address (``10.0.0.1``) or a network containing a
    ``/`` (``10.0.0.0/8``), optionally prefixed with ``!`` to negate it.
    Empty elements (e.g. from a trailing comma) are ignored.
    """

    def __init__(self, expr, rule, logger):

        # single addresses, keyed by their normalized text str(IPAddress)
        self._accepted_ips = dict()
        self._negated_ips = dict()
        # networks, keyed by the text used in the expression
        self._accepted_networks = dict()
        self._negated_networks = dict()

        super(IPFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building IP filter from expr %s" % (expr))
        for expr_str in expr.split(','):
            ipstr = expr_str.strip()
            if not ipstr:
                # tolerate empty elements instead of failing on ipstr[0]
                continue

            negated = ipstr.startswith('!')
            if negated:
                ipstr = ipstr[1:].strip()

            if '/' in ipstr:
                net = ipaddr.IPNetwork(ipstr)
                nets = self._negated_networks if negated else self._accepted_networks
                nets[ipstr] = net
            else:
                addr = ipaddr.IPAddress(ipstr)
                ips = self._negated_ips if negated else self._accepted_ips
                ips[str(addr)] = addr

    def _match_impl(self, obj):
        try:
            addr = ipaddr.IPAddress(str(obj))
        except ValueError:
            self._logger.error("Provided match object %s is not a valid IP address" % str(obj))
            return False
        # compare normalized text so differently written forms of the same
        # address (e.g. IPv6) agree with the keys built in __init__
        addr_str = str(addr)

        if addr_str in self._accepted_ips:
            return True

        # a list of negated addresses matches any other address
        if self._negated_ips and addr_str not in self._negated_ips:
            return True

        # A list of negated networks (!192.168.1.0/24,!192.168.2.0/24) means
        # the intention is to match IPs that DO NOT belong to ANY of them.
        match_status = False
        if self._negated_networks:
            match_status = not any(net.Contains(addr)
                                   for net in self._negated_networks.values())

        if any(net.Contains(addr) for net in self._accepted_networks.values()):
            return True

        return match_status

class UserAgentFilter(Filter):
    """
    Filter matching the SIP User-Agent header.

    The expression is either an exact user agent string or
    ``regex=<pattern>``, a case-insensitive pattern matched with re.match
    (i.e. anchored at the start of the user agent).
    """

    regex_id = 'regex='

    def __init__(self, expr, rule, logger):
        self.user_agent_str = None
        self.regex = None
        self.regex_str = ''

        super(UserAgentFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building user agent filter from expr %s" % (expr))

        # str.strip() returns a new string; keep the stripped result
        expr = expr.strip()
        if expr.startswith(self.regex_id):
            self.regex_str = expr[len(self.regex_id):]
            self.regex = re.compile(self.regex_str, re.IGNORECASE)
        else:
            self.user_agent_str = expr

    def _match_impl(self, obj):
        agent = str(obj)

        if self.regex is not None:
            return self.regex.match(agent) is not None

        return self.user_agent_str is not None and agent == self.user_agent_str

class AccountFilter(Filter):
    """
    Filter matching the account of a failed attempt.

    The expression is one of:
      * ``unknown``: matches attempts flagged as an invalid user;
      * ``regex=<pattern>``: case-insensitive pattern matched with re.match
        (anchored at the start of the account);
      * a comma separated list of accounts, each optionally negated with
        ``!``; a negated list matches any account not in it. Empty
        elements are ignored.
    """

    regex_id = 'regex='

    def __init__(self, expr, rule, logger):
        self._accepted_accounts = dict()
        self._negated_accounts = dict()
        self._unknown = False
        self.regex = None
        self.regex_str = ''

        super(AccountFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building account filter from expr %s" % (expr))

        # str.strip() returns a new string; keep the stripped result
        expr = expr.strip()
        if expr == 'unknown':
            self._unknown = True
        elif expr.startswith(self.regex_id):
            self.regex_str = expr[len(self.regex_id):]
            self.regex = re.compile(self.regex_str, re.IGNORECASE)
        else:
            for expr_str in expr.split(','):
                account = expr_str.strip()
                if not account:
                    # tolerate empty elements instead of failing on account[0]
                    continue
                if account.startswith('!'):
                    account = account[1:]
                    self._negated_accounts[account] = account
                else:
                    self._accepted_accounts[account] = account

    def _match_impl(self, obj, invalid=False):
        """
        Return whether account obj matches this filter.

        invalid is the caller's "invalid user" flag (set by SIPSecRule.match
        when the fail reason is 'invalid-user'); it only matters for the
        ``unknown`` expression.
        """
        account = str(obj)

        if self.regex is not None:
            return self.regex.match(account) is not None

        if self._unknown and invalid:
            return True

        if account in self._accepted_accounts:
            return True

        return bool(self._negated_accounts) and account not in self._negated_accounts

class SIPSecRule(object):
    """
    A security rule built from one row of the rules table.

    The rule counts failed authentication attempts that pass its filters
    (profile, source IP, account, user agent). Once failed_attempts of them
    fall within time_frame minutes, it runs its actions.
    """

    def __init__(self, dbreg, logger):
        self._logger = logger
        self._id = int(dbreg.rule_id)
        self._name = dbreg.name
        self._failed_attempts = dbreg.failed_attempts
        self._time_frame = (dbreg.time_frame * 60)  # minutes -> seconds

        # NULL profile filter means "any profile"
        self._profile_filter = dbreg.profile_filter or ''

        self._actions = []
        self._ip_filter = None
        self._account_filter = None
        self._user_agent_filter = None
        # SIPSecRuleMatch objects keyed by their match key
        self._matches = dict()

        # build filters (NULL or empty expressions mean "no filter")
        if dbreg.src_ip_filter_expr:
            self._ip_filter = IPFilter(dbreg.src_ip_filter_expr, self, self._logger)

        if dbreg.account_filter_expr:
            self._account_filter = AccountFilter(dbreg.account_filter_expr, self, self._logger)

        if dbreg.user_agent_filter_expr:
            self._user_agent_filter = UserAgentFilter(dbreg.user_agent_filter_expr, self, self._logger)

        # build actions
        self._build_actions(dbreg.action_expr)

    def _build_actions(self, expr):
        """
        Parse a comma separated list of ``name=args`` actions.

        Only ``block_ip=<minutes>`` is supported. Empty entries are skipped,
        and unknown or malformed entries are logged and ignored. A NULL
        (None) expression yields no actions.
        """
        self._logger.debug("Building actions from expr %s" % (expr))
        if not expr:
            return
        for action_str in expr.split(','):
            action_str = action_str.strip()
            if action_str == '':
                self._logger.debug("Ignoring empty action for rule '%s'" % (self))
                continue
            action_name, sep, action_args = action_str.partition('=')
            if not sep:
                self._logger.warning("Invalid action '%s' for rule '%s'"
                                     % (action_str, self))
                continue
            if action_name == 'block_ip':
                action = BlockIPAction(action_args, self, self._logger)
            else:
                self._logger.warning("Invalid action '%s' for rule '%s'"
                                     % (action_str, self))
                continue
            self._actions.append(action)

    def match(self, event):
        """
        Feed one authentication event to this rule.

        Returns False when the event is not a failure, fails one of the
        filters, or carries no source IP. Otherwise the event is recorded
        and True is returned; the actions run once enough failures fall
        within the time frame.
        """
        secmon = get_secmon()
        status = event.get_header('auth-status')

        # Only match failure attempts
        if status == 'ok':
            return False

        profile = event.get_header('profile-name')
        src_ip = event.get_header('network-ip')
        # missing username/realm headers must not break string building
        account = (event.get_header('username') or '') + "@" + (event.get_header('realm') or '')
        user_agent = event.get_header('user-agent')
        invalid_user = (event.get_header('fail-reason') == 'invalid-user')

        # Match profile
        if self._profile_filter and profile != self._profile_filter:
            return False

        # Match IP
        if self._ip_filter and not self._ip_filter.match(src_ip):
            return False

        # Match account
        if (self._account_filter
                and not self._account_filter.match(account, invalid_user)):
            return False

        # Match user agent (events without a user agent are not filtered)
        if (user_agent is not None
                and self._user_agent_filter is not None
                and not self._user_agent_filter.match(user_agent)):
            return False

        # Without a source IP there is nothing to track or block
        if not src_ip:
            self._logger.debug("Ignoring event without source IP for rule '%s'" % (self))
            return False

        # The match key groups the events that count towards one trigger:
        # rule, profile (when filtered), source IP and, with an account
        # filter, the regex, the invalid-user marker or the account itself.
        key_parts = ["%d/%s" % (self._id, self._name)]
        if self._profile_filter:
            key_parts.append(profile)
        key_parts.append(src_ip)
        if self._account_filter:
            if self._account_filter.regex is not None:
                key_parts.append(self._account_filter.regex_str)
            elif invalid_user:
                key_parts.append('@invalid-user@')
            else:
                key_parts.append(account)
        match_key = '/'.join(key_parts)

        if match_key in self._matches:
            rule_match = self._matches[match_key]
            # cancel previous safety timer
            secmon.sched.cancel(rule_match.sched_event)
        else:
            rule_match = SIPSecRuleMatch(match_key, self, self._logger)
            rule_match.profile = profile
            rule_match.src_ip = src_ip
            rule_match.account = account
            self._matches[match_key] = rule_match

        # Save event and its capture time
        curtime = time.time()
        event.secmon_time = curtime
        rule_match.events.append(event)

        # schedule an expiration timer for this match in case no events match anymore within the time frame
        rule_match.sched_event = secmon.sched.enter(self._time_frame, 1, self._delete_match, (rule_match, ))

        # Keep only the events still inside the time frame. A negative age
        # means the system clock moved backwards; re-stamp those events.
        valid_events = []
        for regev in rule_match.events:
            age = curtime - regev.secmon_time
            if age < 0:
                regev.secmon_time = curtime
            if age <= self._time_frame:
                valid_events.append(regev)
        rule_match.events[:] = valid_events

        # if not enough attempts have occurred within the valid time frame, we're done here
        if len(rule_match.events) < self._failed_attempts:
            return True

        # Enough events collected now, trigger the actions
        self._logger.info("Executing actions for rule match %s from IP %s" % (match_key, src_ip))
        for action in self._actions:
            action.execute(rule_match)

        # get rid of the match
        secmon.sched.cancel(rule_match.sched_event)
        self._delete_match(rule_match)

        return True

    def _delete_match(self, match):
        """Forget a rule match (also the scheduler's expiration callback)."""
        match_key = str(match)
        self._logger.debug("Deleting rule match '%s'" % (match_key))
        del self._matches[match_key]

    def __str__(self):
        return self._name


class SIPSecMonitor(sngpy.DBService):

    _service_name = 'sipsecmon'

    def __init__(self, logger):
        """
        Set up the monitor's default state, console logging, the scheduler,
        the command line options and the SIGHUP handler.

        :param logger: logger for the service; a console StreamHandler using
                       self._logformat is attached to it.
        """

        # <database> configuration, filled by _parse_db()
        self._dbconf = dict()
        self._dbconf_params = ['connection-string', 'table-prefix']
        self._dbconn = None

        # <switch> configuration, filled by _parse_switch()
        self._swconf = dict()
        self._swconf_params = ['connection-string']

        # seconds to wait between connection retries in _connect()
        self._sleep_interval = 1

        # table names, derived from table-prefix in _parse_db()
        self.blocked_objects_table = ''
        self.rules_table = ''
        self.sched = None

        # SIPSecRule objects, filled by _load_rules()
        self._rules = []

        self._oswc_conn = None

        # iptables chain holding the DROP rules, and the fallback binary
        # used when iptc is not available
        self.chain = 'sip_security'
        self.iptables_bin = '/sbin/iptables'

        self._null = None

        self._logger = logger

        # NOTE(review): _logformat is not set in this method; presumably it
        # is provided by the base class — confirm.
        formatter = logging.Formatter(self._logformat)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        self._logger.addHandler(console_handler)

        self.sched = sngpy.Scheduler(time.time, time.sleep)

        # sink for the stdout/stderr of commands run by cmd_exec()
        self._null = open(os.devnull, 'w')

        # custom SIPSecMon options should be added here via parser.add_option()
        self.parser = OptionParser()
        self.parser.add_option("", "--unblock-ip", dest="unblock_ip",
                            help="Unblock the given IP address")
        self.parser.add_option("", "--init-database", action="store_true",
                            dest="init_database",
                            help="Initialize Database")

        # NOTE(review): the base initializer runs after the attributes and
        # parser above are set; it may depend on them, so keep this order
        # (TODO confirm against sngpy.DBService).
        super(SIPSecMonitor, self).__init__()

        # handle sighup to reload rules and block/unblock objects
        self.reload = False
        def sigreload(signum, frame):
            # Only flag the reload here; _housekeeping() performs it from
            # the main loop.
            secmon = get_secmon()
            secmon.reload = True
        signal.signal(signal.SIGHUP, sigreload)


    def _parse_db(self, params):
        """
        Store the <database> parameters and derive the table names from them.

        Raises ValueError on an unknown parameter or when connection-string
        is missing. table-prefix defaults to ''.
        """
        self._logger.debug("Reading database parameters")
        for param in params:
            pname = param.attrib['name']
            pvalue = param.attrib['value']
            self._logger.debug("database: %s=%s" % (pname, pvalue))
            if pname not in self._dbconf_params:
                raise ValueError("Unknown database XML parameter %s" % pname)
            self._dbconf[pname] = pvalue

        if 'connection-string' not in self._dbconf:
            raise ValueError("Missing database connection-string XML parameter")
        prefix = self._dbconf.setdefault('table-prefix', '')
        self.rules_table = prefix + "rules"
        self.blocked_objects_table = prefix + "blocked_objects"

    def _parse_switch(self, params):
        """
        Store the <switch> parameters.

        Raises ValueError on an unknown parameter or when connection-string
        is missing.
        """
        self._logger.debug("Reading switch parameters")
        for param in params:
            pname = param.attrib['name']
            pvalue = param.attrib['value']
            self._logger.debug("switch: %s=%s" % (pname, pvalue))
            if pname not in self._swconf_params:
                raise ValueError("Unknown switch XML parameter %s" % pname)
            self._swconf[pname] = pvalue

        if 'connection-string' not in self._swconf:
            raise ValueError("Missing switch connection-string XML parameter")
                
    def configure(self):
        """
        Load the service configuration and parse its <database> and
        <switch> sections.

        Returns True on success. On any error, logs it with its traceback
        and returns False.
        """
        try:
            tree = super(SIPSecMonitor, self).configure()

            dbconf = tree.find('database')
            if dbconf is None:
                raise ValueError("Missing <database> configuration")
            self._parse_db(self._get_params(dbconf))

            swconf = tree.find('switch')
            if swconf is None:
                raise ValueError("Missing <switch> configuration")
            self._parse_switch(self._get_params(swconf))
        except Exception:
            # Exception rather than a bare except: KeyboardInterrupt and
            # SystemExit must propagate instead of being reported as a bad
            # configuration.
            self._logger.critical("Failed to configure service")
            exc_type, exc_value, exc_traceback = sys.exc_info()
            self._print_exception(exc_type, exc_value, exc_traceback)
            return False

        return True

    def _db_init(self):
        """
        Create the rules table and the blocked objects table if they do not
        exist yet. Logs and re-raises any error.
        """
        rules_ddl = "CREATE TABLE IF NOT EXISTS `%s` ("  """
                      `rule_id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `name` VARCHAR(255) NULL ,
                      `failed_attempts` INT UNSIGNED NOT NULL ,
                      `time_frame` INT UNSIGNED NOT NULL ,
                      `profile_filter` TEXT NULL ,
                      `src_ip_filter_expr` TEXT NULL ,
                      `account_filter_expr` TEXT NULL ,
                      `user_agent_filter_expr` TEXT NULL ,
                      `action_expr` TEXT NULL ,
                      `comments` TEXT NULL ,
                      PRIMARY KEY (`rule_id`) ,
                      UNIQUE INDEX `rule_id_UNIQUE` (`rule_id` ASC) )
                   ENGINE = InnoDB;
                   """
        blocked_ddl = "CREATE TABLE IF NOT EXISTS `%s` (" """
                      `blocked_object_id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `rule_id` INT UNSIGNED NOT NULL ,
                      `name` VARCHAR(255) NOT NULL ,
                      `type` VARCHAR(10) NOT NULL ,
                      `block_time` DATETIME NOT NULL ,
                      `block_expiration` DATETIME NOT NULL ,
                      PRIMARY KEY (`blocked_object_id`) ,
                      UNIQUE INDEX `blocked_object_id_UNIQUE` (`blocked_object_id` ASC) ,
                      INDEX `rule_id` (`rule_id` ASC))
                   ENGINE = InnoDB;
                   """
        try:
            # rules table first, then the blocked objects table
            for ddl, table in ((rules_ddl, self.rules_table),
                               (blocked_ddl, self.blocked_objects_table)):
                self.sql_exec(ddl % (table))
        except:
            self._logger.error("Failed to initialize database")
            raise

    def _load_rules(self):
        """
        Build a SIPSecRule for every row of the rules table and append it
        to self._rules. Logs and re-raises any error.
        """
        try:
            loaded = 0
            for row in self.sql_exec("SELECT * FROM %s" % (self.rules_table)):
                self._rules.append(SIPSecRule(row, self._logger))
                loaded += 1
            self._logger.info("Loaded %d rules" % (loaded))
        except:
            self._logger.error("Failed to load rules")
            raise

    def cmd_exec(self, cmd, errcheck=True):
        """
        Run cmd (an argv list) with its stdout/stderr sent to os.devnull.

        Returns the exit status. When errcheck is true, a non-zero status is
        logged as an error.
        """
        cmdline = ' '.join(cmd)
        self._logger.debug("Executing command %s" % (cmdline))
        status = subprocess.call(cmd, stdout=self._null, stderr=self._null)
        if status and errcheck:
            self._logger.error("Command %s returned %d" % (cmdline, status))
        return status

    def _remove_input_rule(self):
        """
        Remove the INPUT rule that jumps to self.chain, if there is one.

        Best effort: EAGAIN from iptc is retried and any other failure is
        ignored; the iptables command runs without error checking.
        """
        if iptc_enabled:
            while True:
                try:
                    iptc.TABLE_FILTER.restart()
                    input_chain = iptc.Chain(iptc.TABLE_FILTER, 'INPUT')
                    rule = iptc.Rule()
                    rule.target = iptc.Target(rule, self.chain)
                    input_chain.delete_rule(rule)
                    iptc.TABLE_FILTER.close()
                    break
                except (iptc.IPTCError, iptc.XTablesError):
                    if ctypes.get_errno() != errno.EAGAIN:
                        break  # best effort: the rule may not exist
                    # log through self: no 'secmon' name is bound in this method
                    self._logger.debug("Kernel consistency check failed at remove_input_rule, retrying..")
                    time.sleep(0.05)
        else:
            cmd = [self.iptables_bin, '--table', 'filter', '--delete', 'INPUT', '-j', self.chain]
            self.cmd_exec(cmd, False)

    def _flush_chain(self):
        """
        Flush and delete self.chain, if it exists.

        Best effort: EAGAIN from iptc is retried and any other failure is
        ignored; the iptables commands run without error checking.
        """
        if iptc_enabled:
            while True:
                try:
                    iptc.TABLE_FILTER.restart()
                    sec_chain = iptc.Chain(iptc.TABLE_FILTER, self.chain)
                    sec_chain.flush()
                    iptc.TABLE_FILTER.delete_chain(sec_chain)
                    iptc.TABLE_FILTER.close()
                    break
                except (iptc.IPTCError, iptc.XTablesError):
                    if ctypes.get_errno() != errno.EAGAIN:
                        break  # best effort: the chain may not exist
                    # log through self: no 'secmon' name is bound in this method
                    self._logger.debug("Kernel consistency check failed at flush_chain, retrying..")
                    time.sleep(0.05)
        else:
            cmd = [self.iptables_bin, '--table', 'filter', '--flush', self.chain]
            self.cmd_exec(cmd, False)
            cmd = [self.iptables_bin, '--table', 'filter', '--delete-chain', self.chain]
            self.cmd_exec(cmd, False)

    def _create_chain(self):
        """
        Create self.chain and insert a rule into INPUT that jumps to it.

        With iptc, EAGAIN is retried and any other error is logged and
        re-raised. Without iptc, command failures are only logged (by
        cmd_exec).
        """
        if iptc_enabled:
            while True:
                try:
                    iptc.TABLE_FILTER.restart()
                    sec_chain = iptc.Chain(iptc.TABLE_FILTER, self.chain)
                    iptc.TABLE_FILTER.create_chain(sec_chain)

                    input_chain = iptc.Chain(iptc.TABLE_FILTER, 'INPUT')
                    rule = iptc.Rule()
                    rule.target = iptc.Target(rule, self.chain)
                    input_chain.insert_rule(rule)
                    iptc.TABLE_FILTER.close()
                    break
                except (iptc.IPTCError, iptc.XTablesError) as e:
                    if ctypes.get_errno() != errno.EAGAIN:
                        self._logger.error("Failed creating %s chain: %s" % (self.chain, str(e)))
                        raise
                    # log through self: no 'secmon' name is bound in this method
                    self._logger.debug("Kernel consistency check failed at create_chain, retrying..")
                    time.sleep(0.05)
        else:
            cmd = [self.iptables_bin, '--table', 'filter', '--new-chain', self.chain]
            self.cmd_exec(cmd)

            # Add a rule to the INPUT chain to jump to our newly created chain
            cmd = [self.iptables_bin, '--table', 'filter', '--insert', 'INPUT', '-j', self.chain]
            self.cmd_exec(cmd)

    def _setup_blocked_objects(self):
        """
        Rebuild the firewall state from scratch.

        Drops our INPUT jump and chain, recreates them, re-blocks the IPs
        whose block is still active and lifts the expired blocks. Logs and
        re-raises any error.
        """
        try:
            self._logger.info("Setting up blocked objects (iptc=%s)" % str(iptc_enabled))
            steps = (
                self._remove_input_rule,        # our existing INPUT rule (if any)
                self._flush_chain,              # flush and delete our chain (if any)
                self._create_chain,             # a brand new chain
                BlockIPAction.block_all,        # keep still-valid blocks in place
                BlockIPAction.unblock_expired,  # and lift the expired ones
            )
            for step in steps:
                step()
        except Exception as e:
            self._logger.error("Failed to setup blocked objects: %s" % (str(e)))
            raise

    def unblock_ip(self, ipaddr):
        """
        Connect to the database, then lift the block on the given IP.

        This is the command-line unblock path, which exits before run() is
        ever called, so the database connection is opened here.

        Returns whatever BlockIPAction.unblock_ip_ex() returns; the caller
        treats it as a success flag.

        NOTE: the parameter name shadows the `ipaddr` module inside this
        method. It is kept for keyword-argument compatibility.
        """
        self._db_connect()
        outcome = BlockIPAction.unblock_ip_ex(ipaddr)
        return outcome

    def init_database(self):
        """
        Connect to the database; connecting also runs the initialization.

        Returns:
            True if self._db_connect() completed, False if it raised an
            Exception. Connection errors are already logged by _db_connect().

        KeyboardInterrupt and SystemExit are not swallowed; they propagate
        to the caller.
        """
        try:
            self._db_connect()
        except Exception:
            # Best-effort by design: the caller maps the result to an exit code.
            return False
        return True

    def _housekeeping(self):
        self.sched.fast_run()
        if self.reload:
            self._reload()

    def _connect(self):
        """
        Block until connected to the event provider or the daemon stops.

        Always starts over from a new connection: self._oswc_conn is reset
        to None. Each pass:
          1. runs housekeeping;
          2. if there is no connection object yet, creates one and registers
             the REGISTER_ATTEMPT and SIP_LIMIT_EXCEEDED listeners on it;
          3. tries to connect.

        Failures are logged and retried after self._sleep_interval seconds.
        Returns once connected, or when self.daemon_alive becomes false
        (in which case the connection may not be established).
        """
        self._oswc_conn = None
        conn_str = self._swconf['connection-string']
        while self.daemon_alive:
            self._housekeeping()

            if self._oswc_conn is None:
                new_conn = oswc.create_connection(conn_str, logger=self._logger)
                if new_conn is None:
                    self._logger.error("Failed to create connection to %s" % (conn_str))
                    time.sleep(self._sleep_interval)
                    continue
                self._oswc_conn = new_conn

                reg_listener = RegistrationListener(self._rules, self._logger)
                reg_listener.set_filter(['REGISTER_ATTEMPT'])
                new_conn.add_event_listener(reg_listener)

                limits_listener = SIPLimitsListener(self._logger)
                limits_listener.set_filter(['SIP_LIMIT_EXCEEDED'])
                new_conn.add_event_listener(limits_listener)

            if self._oswc_conn.connect():
                self._logger.info("Connected to %s" % (conn_str))
                return

            self._logger.debug("Failed to connect to %s" % (conn_str))
            time.sleep(self._sleep_interval)

    def _reload(self):
        """
        Handle a pending reload request.

        Clears the request flag, then forces an unblock of expired objects.

        Known limitation: scheduled unblock callbacks are not removed here.
        A callback may later fire for an IP that is already unblocked or,
        worse, unblock an entry that a different rule triggered. In practice
        this is minor: an IP unblocked by hand is usually trusted, and
        blocking it again points to a rule misconfiguration.
        """
        self.reload = False
        self._logger.info("Forcing unblock of expired objects")
        BlockIPAction.unblock_expired()

    def _db_connect(self):
        self._dbconn = None
        try:
            self._dbconn = pyodbc.connect(self._dbconf['connection-string'])
        except Exception, e:
            self._logger.error("Failed to connect to the database %s: %s" % (self._dbconf['connection-string'], str(e)))
            raise

        # initialize the database
        self._db_init()

    def run(self):
        """
        Run the monitor in the foreground until self.daemon_alive goes false.

        Startup, in order:
          - the base-class run();
          - database connection;
          - rule loading;
          - iptables chain setup (re-syncing blocked objects);
          - connection to the event source.

        Event loop: run housekeeping, then wait for an event for at most the
        time until the next scheduled callback. If the server disconnects,
        reconnect. KeyboardInterrupt stops the loop. Any other exception is
        printed and the loop resumes after self._sleep_interval seconds.
        """
        super(SIPSecMonitor, self).run()

        # connect to the database
        self._db_connect()

        # Load the rules
        self._load_rules()

        # Verify blocked objects are still blocked and unblock any expired registers
        self._setup_blocked_objects()

        # Connect to the event source
        self._connect()

        self._logger.info("%s is now running" % (self._service_name))

        # Event loop
        while self.daemon_alive:
            try:
                self._housekeeping()
                sched_sleep = self.sched.next_event_time_delta()
                if sched_sleep <= 0:
                    # Never hand receive_event() a zero or negative timeout.
                    # Zero has been observed to block, and the delta can come
                    # back non-positive when time goes backwards. Wait one
                    # second instead. (An identity test like `is 0` would also
                    # miss 0.0/0L.)
                    sched_sleep = 1

                # Timeout appears to be in milliseconds (hence *1000) -- confirm in oswc.
                event = self._oswc_conn.receive_event(timeout=sched_sleep * 1000)
                if event is not None and str(event) == 'SERVER_DISCONNECTED':
                    self._logger.info("Server connection lost")
                    self._connect()

            except KeyboardInterrupt:
                self._logger.info("Stopping %s. User aborted." %
                    (self._service_name))
                break
            except:
                # Keep the daemon alive on unexpected errors; back off a bit.
                exc_type, exc_value, exc_traceback = sys.exc_info()
                self._print_exception(exc_type, exc_value, exc_traceback)
                time.sleep(self._sleep_interval)

        self._logger.info("%s is now terminating" % (self._service_name))


class RegistrationListener(oswc.EventListener):
    """
    Event listener that passes each received event to every configured
    rule, letting each rule decide whether it matches.

    The monitor's connection setup registers it with a REGISTER_ATTEMPT
    event filter.
    """

    def __init__(self, rules, logger):
        """
        Args:
            rules: iterable of rule objects exposing match(event).
            logger: logging.Logger used for debug output.
        """
        # super() must name this class. super(oswc.EventListener, self)
        # started the lookup *after* EventListener and skipped its __init__.
        # The base is initialized first so it cannot clobber our attributes.
        super(RegistrationListener, self).__init__()
        self._rules = rules
        self._logger = logger

    def on_event(self, event):
        """Hand *event* to every rule's match()."""
        self._logger.debug("Received registration event %s\n" % event)
        for rule in self._rules:
            rule.match(event)

class SIPLimitsListener(oswc.EventListener):
    """
    Event listener for SIP limit violations.

    It blocks the offending source IP (the 'network-ip' header) for the
    limit period (the 'limit-seconds' header).
    """

    def __init__(self, logger):
        """
        Args:
            logger: logging.Logger used for warnings and errors.
        """
        # super() must name this class, otherwise EventListener.__init__ is skipped.
        super(SIPLimitsListener, self).__init__()
        self._logger = logger

    def on_event(self, event):
        """
        Someone exceeded a limit, we'll block their IP
        for the limit seconds to get them back inline
        with the required limits. Ideally we would block
        for only the remaining time of the limit period, but
        for now we block for a full period.

        Events missing the source IP, or carrying a non-integer
        'limit-seconds', are logged and ignored instead of raising.
        """
        network_ip = event.get_header('network-ip')
        limit_seconds = event.get_header('limit-seconds')
        self._logger.warning("Received sip limits event for sip agent '%s' from %s (usage exceeded %s requests per %s seconds, method = %s, host = %s)\n"
                        % (event.get_header('user-agent'), network_ip,
                        event.get_header('limit-max'), limit_seconds,
                        event.get_header('sip-method'), event.get_header('limit-host')))

        if not network_ip:
            self._logger.error("SIP limits event without network-ip header, nothing to block")
            return

        try:
            block_seconds = int(limit_seconds)
        except (TypeError, ValueError):
            self._logger.error("SIP limits event for %s has invalid limit-seconds '%s', not blocking"
                            % (network_ip, limit_seconds))
            return

        # The first argument (0) is kept from the original call; its meaning is
        # defined by BlockIPAction.perform_block (not visible here).
        BlockIPAction.perform_block(0, network_ip, block_seconds)



# Module-wide logger, named after the service. Its own level is opened up
# to DEBUG; what is actually emitted depends on the handlers attached to it,
# which are not configured in this section.
logger = logging.getLogger(SIPSecMonitor._service_name)
logger.setLevel(level=logging.DEBUG)

## main() ##

def get_secmon():
    """Return the module-level SIPSecMonitor instance (None until it has been created)."""
    monitor = secmon
    return monitor

secmon = None
try:
    secmon = SIPSecMonitor(logger)

    if secmon.options.stop is not None:
        secmon.stop()
        sys.exit(0)

    if secmon.options.restart is not None:
        secmon.restart()
        sys.exit(0)

    if secmon.options.conf_path is None:
        # convenient way to retrieve -c option when service is running
        secmon.options.conf_path = sngpy.Service.find_conf_path()

    if secmon.options.conf_path is None:
        secmon.parser.print_help()
        secmon.parser.error("-c is required to find the configuration path")
        sys.exit(1)

    if secmon.configure() is False:
        secmon.parser.error("Failed to configure daemon using file %s" % secmon.options.conf_path)
        sys.exit(1)

    if secmon.options.unblock_ip:
        if secmon.unblock_ip(secmon.options.unblock_ip):
            sys.exit(0)
        else:
            sys.exit(1)

    if secmon.options.init_database:
        if secmon.init_database():
            sys.exit(0)
        else:
            sys.exit(1)

    # Following options (either start or run) should not be executed if the pid file exists
    if os.path.exists(secmon.pidfile):
        logger.error("Service seems to be running already, pid file %s already exists" % secmon.pidfile)
        sys.exit(1)

    # Decide whether to run in the background (Daemon mode) or foreground
    if secmon.options.start is not None:
        secmon.start()
    else:
        secmon.run()

    sys.exit(0)
except KeyboardInterrupt:
    logger.info("Received keyboard interrupt, exiting.")
    sys.exit(0)
except SystemExit, e:
    """
    We just catch this so it won't end up in the catch-all below
    but we still must raise the exception if we want python to
    exit with the provided sys.exit() return code
    """
    if not sys.stdin.isatty():
        logger.info("SystemExit status %d" % e.code)
    raise
except:
    exc_type, exc_value, exc_traceback = sys.exc_info()
    sngpy.print_exception(exc_type, exc_value, exc_traceback)
    if not sys.stdin.isatty():
        logger.error("Unexpected exception %s/%s, aborting." % exc_type % exc_value)
    sys.exit(1)

