#!/usr/bin/python2.7
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2012  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author(s)
* Moises Silva <moy@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import sys
import os
import time
import logging
import oswc
import sngpy
import pyodbc
import ipaddr
import subprocess
import re
import signal
import errno
import pytables as iptc
from pytables.helpers import *
import ctypes
from optparse import OptionParser
from datetime import datetime
from datetime import timedelta


#######################
## Maintenance Notes ##

## About the 'type' field in blocked_objects_table ##
# blocked_objects_table was originally intended to hold both blocked IPs and SIP
# accounts. However, the 'account' blocking feature was never finished, so the
# 'type' field in blocked_objects_table is effectively unused (it is always 'ip').
# The newer expired_objects_table temporarily holds recently expired objects;
# it has no 'type' column because one is not needed. Hopefully the column will
# one day be removed from blocked_objects_table as well (or the account
# blocking feature will actually be implemented).

#######################

# Module-wide logger slot. SIPSecmonChain.__init__ stores its logger here so
# the @iptc_command decorators, which are handed get_logger, can reach it.
sipsecmon_logger = None


def get_logger():
    """Return the logger registered by SIPSecmonChain, or None before that."""
    current = sipsecmon_logger
    return current

class SIPSecmonChain(object):
    """
    Owns the iptables chains that hold one DROP rule per blocked source IP.

    chain_name maps IP version (4, 6) to the name of the chain to use in that
    version's FILTER table. Both chains are set up when the object is built.

    The @iptc_command methods are generators: they first yield the table they
    operate on and then perform the change. NOTE(review): the driving logic
    (presumably commit/retry handling) lives in pytables.helpers.iptc_command;
    confirm there before relying on specific semantics.
    """
    def __init__(self, chain_name, logger):
        # Publish the logger for get_logger(), which the @iptc_command
        # decorators on the methods below are given.
        global sipsecmon_logger
        sipsecmon_logger = logger

        self.logger = logger
        self.chain_name = chain_name

        # ipv4 and ipv6 have different tables (opened with autocommit off)
        self.table = {
            4: iptc.Table(iptc.Table.FILTER, autocommit=False),
            6: iptc.Table6(iptc.Table6.FILTER, autocommit=False)
        }

        # and different constructors for rules
        self.rule_cls = {
            4: iptc.Rule,
            6: iptc.Rule6
        }

        # and separate chains (the chain objects handed back by _setup_chain
        # through iptc_return)
        self.chains = {
            4: self._setup_chain(4),
            6: self._setup_chain(6)
        }

    def create_rule(self, src_addr):
        """
        Parse src_addr and return (address, empty rule of the matching IP
        version). ipaddr raises ValueError for an invalid address.
        """
        addr = ipaddr.IPAddress(src_addr)
        return addr, self.rule_cls[addr.version]()

    @iptc_command(get_logger)
    def append_rule(self, src_addr):
        """Append a rule dropping traffic from src_addr to its version's chain."""
        addr, rule = self.create_rule(src_addr)
        yield self.chains[addr.version].table

        rule.src = str(addr)
        rule.target = rule.create_target('DROP')
        self.chains[addr.version].append_rule(rule)

    @iptc_command(get_logger)
    def delete_rule(self, src_addr):
        """
        Delete the DROP rule for src_addr from its version's chain.

        The rule is rebuilt exactly as append_rule builds it so the chain can
        find the existing one (presumably matched by content; verify against
        the iptables binding).
        """
        addr, rule = self.create_rule(src_addr)
        yield self.chains[addr.version].table

        rule.src = str(addr)
        rule.target = rule.create_target('DROP')
        self.chains[addr.version].delete_rule(rule)

    @iptc_command(get_logger)
    def _setup_chain(self, version):
        """
        Prepare the chain for `version` and hook it into INPUT.

        If the chain does not exist it is created. If it exists, one existing
        INPUT jump rule to it is removed (errors ignored, e.g. when there is
        none) so that re-inserting does not duplicate it. A jump rule to the
        chain is then inserted into INPUT and the chain is flushed, so any
        previously blocked IPs must be re-added by the caller. The chain object
        is handed back through iptc_return().
        """
        table = self.table[version]
        yield table

        input_chain = iptc.Chain(table, 'INPUT')
        rule = self.rule_cls[version]()

        if not table.is_chain(self.chain_name[version]):
            chain = table.create_chain(self.chain_name[version])
        else:
            rule.target = rule.create_target(self.chain_name[version])
            try:
                input_chain.delete_rule(rule)
            except (iptc.IPTCError, iptc.XTablesError):
                pass
            chain = iptc.Chain(table, self.chain_name[version])

        rule.target = rule.create_target(self.chain_name[version])
        input_chain.insert_rule(rule)

        chain.flush()

        iptc_return(chain)


def perform_block(rule_id, target_ip, block_seconds):
    """
    Record a block for target_ip in the database and apply it in the firewall.

    :param rule_id: id of the rule that triggered the block (SIPLimitsListener
                    passes 0 for limit-exceeded events).
    :param target_ip: IP address to block.
    :param block_seconds: requested block duration. 0, or one day (86400 s)
                          or more, means a permanent block. Temporary blocks
                          are doubled for each offence still remembered in the
                          expired-objects cache.

    Does nothing beyond logging a warning if target_ip already has a block
    record.
    """
    secmon = get_secmon()

    # If this IP already has a block record, ignore the request. It is most
    # likely a duplicate caused by events arriving faster than we could block
    # the IP (a small race).
    query = "SELECT * FROM {} WHERE name = ?".format(secmon.blocked_objects_table)
    cursor = secmon.sql_exec(query, target_ip)
    row = cursor.fetchone()
    if row:
        secmon._logger.warning("Not blocking IP {} again, block will expire on {}".format(row.name, row.block_expiration))
        return

    # Blocks of one day or longer are treated as permanent
    # (block_seconds == 0 means permanent).
    if block_seconds >= 86400:
        block_seconds = 0

    # Purge expiration-cache records older than expiration_cache_period
    # seconds, except for IPs that are currently blocked.
    query = """ DELETE FROM {} WHERE
                DATE_ADD(last_expiration, INTERVAL ? second) < NOW()
                AND name NOT IN (SELECT name FROM {})
            """.format(secmon.expired_objects_table, secmon.blocked_objects_table)
    secmon.sql_exec(query, secmon.expiration_cache_period)

    # If the IP is still in the expired-objects cache, bump its block count
    # and escalate the block time to block_seconds * 2**block_count (using the
    # count read before the increment). Permanent blocks skip this.
    if block_seconds > 0:
        query = "SELECT * FROM {} WHERE name = ?".format(secmon.expired_objects_table)
        cursor = secmon.sql_exec(query, target_ip)
        row = cursor.fetchone()
        if row:
            query = "UPDATE {} SET block_count = block_count + 1 WHERE name = ?".format(secmon.expired_objects_table)
            secmon.sql_exec(query, target_ip)
            block_seconds = block_seconds * (2 ** row.block_count)
            # Escalating to one day or more turns the block permanent
            if block_seconds >= 86400:
                block_seconds = 0
                secmon._logger.warning("IP {} will be blocked permanently due to {} repeated offences".format(row.name, row.block_count))
            else:
                secmon._logger.warning("IP {} will be blocked for {} seconds due to {} repeated offences".format(row.name, block_seconds, row.block_count))

    # Insert the actual block record
    now = datetime.now()

    if block_seconds > 0:
        block_period = timedelta(0, block_seconds)
        xtime = now + block_period
    else:
        # Permanent blocking: datetime.max formats as 9999-12-31 23:59:59,
        # which block_all() recognises by its '9999' prefix.
        xtime = datetime.max

    query = """ INSERT INTO {} (rule_id, name, type, block_time, block_expiration)
                        VALUES (?, ?, ?, ?, ?)
            """.format(secmon.blocked_objects_table)
    secmon.sql_exec(query,
                    rule_id,
                    target_ip,
                    'ip',
                    now.strftime('%Y-%m-%d %H:%M:%S'),
                    xtime.strftime('%Y-%m-%d %H:%M:%S'))

    # Do the actual firewall block (0 seconds = no unblock is scheduled)
    fw_block_ip(target_ip, block_seconds)


def unblock_ip_ex(src_addr):
    """
    Force-unblock src_addr on behalf of an external process (--unblock-ip).

    The block record's expiration is pushed into the past and the running
    daemon, if any, is sent SIGHUP. On SIGHUP the daemon runs unblock_expired()
    from its housekeeping loop. The IP's entry in the expired-objects cache is
    also dropped, so a later block starts again without an escalated penalty.

    Returns True if a block record was updated, False otherwise.
    """
    secmon = get_secmon()
    blocked_tbl = secmon.blocked_objects_table
    expired_tbl = secmon.expired_objects_table

    # Force the expiration to be immediate
    expire_sql = "UPDATE {} SET block_expiration = '0000-00-00 00:00:00' WHERE name = ?".format(blocked_tbl)
    updated = secmon.sql_exec(expire_sql, src_addr)
    if not updated.rowcount > 0:
        secmon._logger.warning("Could not force unblock of IP {}: IP not found or record already expired".format(src_addr))
        return False

    # Forget the offence history so a subsequent block is not escalated
    secmon.sql_exec("DELETE FROM {} WHERE name = ?".format(expired_tbl), src_addr)

    # Wake up the monitoring daemon immediately, if it is running
    daemon = sngpy.Service.find_process()
    if daemon is not None:
        daemon.send_signal(signal.SIGHUP)

    secmon._logger.info("Forced unblock of IP {}".format(src_addr))
    return True


def fw_block_ip(src_addr, seconds):
    """
    Add a firewall DROP rule for src_addr and, for temporary blocks, schedule
    its removal.

    :param src_addr: IP address (v4 or v6) to block.
    :param seconds: block duration in seconds. 0 (or less) means a permanent
                    block, for which no unblock is scheduled.
    """
    secmon = get_secmon()
    if seconds > 0:
        # Floor division keeps the approximate minute count integral under
        # both Python 2 and Python 3 semantics.
        minutes = seconds // 60
        secmon._logger.warning("Blocking IP {} for about {} minutes ({} seconds exactly)".format(src_addr, minutes, seconds))
    else:
        # Do not report a permanent block as "about 0 minutes"
        secmon._logger.warning("Blocking IP {} permanently".format(src_addr))

    secmon.sip_chain.append_rule(src_addr)

    if seconds > 0:
        secmon.sched.enter(seconds, 1, unblock_ip, (src_addr, ))


def sql_unblock_ip(src_addr):
    """
    Remove src_addr's block record and stamp NOW() as its last expiration in
    the expired-objects cache (inserting the cache row if it is missing).
    """
    secmon = get_secmon()
    statements = (
        # Forget the active block...
        "DELETE FROM {} WHERE name = ?".format(secmon.blocked_objects_table),
        # ...and remember when it ended, for repeat-offence escalation
        """ INSERT INTO {} (name, last_expiration) VALUES (?, NOW())
                ON DUPLICATE KEY UPDATE last_expiration = NOW()
            """.format(secmon.expired_objects_table),
    )
    for statement in statements:
        secmon.sql_exec(statement, src_addr)


def unblock_ip(src_addr):
    """
    Lift the block on src_addr. The firewall rule is removed first, and the
    database record is cleared only if that removal succeeded.
    """
    removed = fw_unblock_ip(src_addr)
    if not removed:
        return
    sql_unblock_ip(src_addr)


def fw_unblock_ip(src_addr):
    """
    Remove the firewall DROP rule for src_addr.

    Returns True on success and False if the iptables layer raised
    IPTCError/XTablesError (presumably also when no matching rule exists).
    """
    secmon = get_secmon()
    secmon._logger.warning("Unblocking IP {}".format(src_addr))
    try:
        secmon.sip_chain.delete_rule(src_addr)
    except (iptc.IPTCError, iptc.XTablesError):
        return False
    else:
        return True


def unblock_expired():
    """
    Unblock every object whose block_expiration is at or before now.

    All matching rows are fetched before any unblocking starts. unblock_ip()
    issues further statements through secmon.sql_exec, including a DELETE on
    the very table being read. Iterating a live pyodbc cursor while doing that
    is driver-dependent at best, and breaks outright if sql_exec reuses its
    cursor.
    """
    secmon = get_secmon()
    now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    query = "SELECT * FROM {} WHERE block_expiration <= ?".format(secmon.blocked_objects_table)
    expired = secmon.sql_exec(query, now).fetchall()
    for obj in expired:
        unblock_ip(obj.name)


def block_all():
    """
    Re-apply firewall blocks for every object whose block has not expired yet.

    This runs while the blocked objects are being set up; the firewall chains
    have just been flushed, so every still-valid block recorded in the
    database must be re-added. Permanent blocks (expiration in year 9999) are
    re-blocked permanently. Temporary blocks are re-blocked for the whole
    seconds they have left, rounded up.
    """
    secmon = get_secmon()
    now = datetime.now()
    now_str = now.strftime('%Y-%m-%d %H:%M:%S')
    # Read from database which IPs are supposed to be blocked
    dbquery = "SELECT * FROM {} WHERE block_expiration > ?".format(secmon.blocked_objects_table)
    cursor = secmon.sql_exec(dbquery, now_str)

    # Block each object as needed
    for obj in cursor:
        blk_exp_str = str(obj.block_expiration)
        secmon._logger.debug("Re-blocking IP {}, block_expiration={}".format(obj.name, blk_exp_str))
        if blk_exp_str.startswith('9999'):
            fw_block_ip(obj.name, 0)
            continue

        # pyodbc normally returns DATETIME columns as datetime objects; fall
        # back to parsing the text form for drivers that return strings.
        if isinstance(obj.block_expiration, datetime):
            xtime = obj.block_expiration
        else:
            xtime = datetime.strptime(blk_exp_str, '%Y-%m-%d %H:%M:%S')

        # Whole seconds left, rounded up. timedelta.seconds alone would drop
        # the days component and truncate fractions. A block with under one
        # second left would then become 0, which fw_block_ip() treats as
        # *permanent* (no unblock is ever scheduled).
        period = xtime - now
        remaining = period.days * 86400 + period.seconds
        if period.microseconds:
            remaining += 1
        fw_block_ip(obj.name, max(remaining, 1))



class SIPSecRuleMatch(object):
    """
    Bookkeeping for one in-progress match of a SIPSecRule.

    `name` is the match key built by the rule and is what str() returns.
    """

    def __init__(self, name, rule, logger):
        self._name, self._rule, self._logger = name, rule, logger

        # Events that counted towards this match
        self.events = []
        # Filled in by the owning rule from the first matching event
        self.src_ip = self.profile = self.account = ''
        # Handle of the pending expiration timer, if any
        self.sched_event = None

        logger.debug("New rule match '{}'".format(name))

    def __str__(self):
        return self._name


class SIPSecAction(object):
    """
    Base class for actions a SIPSecRule runs once a match reaches its
    failure threshold. `args` is the text after '=' in the action expression.
    """

    def __init__(self, args, rule, logger):
        self._args, self._rule, self._logger = args, rule, logger

    def execute(self, rule_match):
        """Run the action for rule_match; the base implementation does nothing."""
        return None


class BlockIPAction(SIPSecAction):
    """
    Rule action "block_ip=<minutes>": block the offending source IP.

    This is useful to block script-kiddies using scanners to find vulnerable
    accounts.
    """

    def __init__(self, args, rule, logger):
        super(BlockIPAction, self).__init__(args, rule, logger)
        # The action argument is the block duration in minutes
        self._block_time = int(args)
        self._logger.debug("Created action to block IP for {} minutes".format(self._block_time))

    def execute(self, rule_match):
        """Block rule_match.src_ip for the configured number of minutes."""
        seconds = self._block_time * 60
        perform_block(rule_match._rule._id, rule_match.src_ip, seconds)


class Filter(object):
    """
    Base class for rule filters.

    Subclasses parse `expr` in their constructor and override match(). This
    default never matches.
    """

    def __init__(self, expr, rule, logger):
        self._expr, self._rule, self._logger = expr, rule, logger

    def match(self, obj, *args):
        """Return whether obj passes the filter (always False here)."""
        return False


class IPFilter(Filter):
    """
    Match source IPs against a comma-separated list of addresses/networks.

    Each element is an IP address ("1.2.3.4") or a CIDR network
    ("10.0.0.0/8"), optionally prefixed with "!" to negate it. An address
    matches when:
      * it is one of the accepted addresses or lies in an accepted network
        (explicit acceptance wins over negations), or
      * negated elements exist and the address is neither a negated address
        nor inside any negated network. A list of negations means "anything
        except these", e.g. (!192.168.1.0/24,!192.168.2.0/24) matches IPs
        that belong to neither network.
    Empty elements (e.g. from a trailing comma) are ignored.
    """

    def __init__(self, expr, rule, logger):

        # Addresses are keyed by their canonical text (str(IPAddress)),
        # networks by the text given in the expression.
        self._accepted_ips = {}
        self._negated_ips = {}
        self._accepted_networks = {}
        self._negated_networks = {}

        super(IPFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building IP filter from expr {}".format(expr))
        for expr_str in expr.split(','):
            ipstr = expr_str.strip()
            if not ipstr:
                # Tolerate empty elements instead of failing on ipstr[0]
                continue

            negated = ipstr.startswith('!')
            if negated:
                ipstr = ipstr[1:].strip()

            if '/' in ipstr:
                net = ipaddr.IPNetwork(ipstr)
                target = self._negated_networks if negated else self._accepted_networks
                target[ipstr] = net
            else:
                addr = ipaddr.IPAddress(ipstr)
                target = self._negated_ips if negated else self._accepted_ips
                # Canonical key so differently written (e.g. IPv6) forms of
                # the same address compare equal in match()
                target[str(addr)] = addr

    def match(self, obj):
        """Return True if the IP address text of `obj` passes the filter."""
        try:
            addr = ipaddr.IPAddress(str(obj))
        except ValueError:
            self._logger.error("Provided match object {} is not a valid IP address".format(obj))
            return False
        addr_key = str(addr)

        # Explicit acceptance always wins
        if addr_key in self._accepted_ips:
            return True
        if any(addr in net for net in self._accepted_networks.values()):
            return True

        # Without negations, only explicit acceptance can match
        if not (self._negated_ips or self._negated_networks):
            return False

        # With negations, match anything not excluded by ANY of them; both
        # negated addresses and negated networks must be honoured together.
        if addr_key in self._negated_ips:
            return False
        return not any(addr in net for net in self._negated_networks.values())


class UserAgentFilter(Filter):
    """
    Match SIP User-Agent strings.

    The expression is either "regex=<pattern>" (compiled case-insensitively
    and applied with re.match, i.e. anchored at the start of the agent
    string), or a literal agent string that must match exactly. Surrounding
    whitespace in the expression is ignored.
    """

    regex_id = 'regex='

    def __init__(self, expr, rule, logger):
        self.user_agent_str = None
        self.regex = None
        self.regex_str = ''

        super(UserAgentFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building user agent filter from expr {}".format(expr))

        # str.strip() returns a new string, so the result must be assigned
        expr = expr.strip()
        if expr.startswith(self.regex_id):
            self.regex_str = expr[len(self.regex_id):]
            self.regex = re.compile(self.regex_str, re.IGNORECASE)
        else:
            self.user_agent_str = expr

    def match(self, obj):
        """Return True if the user agent text of `obj` passes the filter."""
        agent = str(obj)

        if self.regex is not None:
            return self.regex.match(agent) is not None

        return self.user_agent_str is not None and agent == self.user_agent_str


class AccountFilter(Filter):
    """
    Match SIP accounts ("user@realm").

    The expression is one of:
      * "unknown": match attempts flagged as invalid users;
      * "regex=<pattern>": case-insensitive re.match against the account;
      * a comma-separated account list, where "!" negates an entry. Accepted
        accounts match; with negated entries present, any account not
        negated also matches.
    Surrounding whitespace and empty list elements are ignored.
    """

    regex_id = 'regex='

    def __init__(self, expr, rule, logger):
        self._accepted_accounts = {}
        self._negated_accounts = {}
        self._unknown = False
        self.regex = None
        self.regex_str = ''

        super(AccountFilter, self).__init__(expr, rule, logger)

        self._logger.debug("Building account filter from expr {}".format(expr))

        # str.strip() returns a new string, so the result must be assigned
        expr = expr.strip()
        if expr == 'unknown':
            self._unknown = True
        elif expr.startswith(self.regex_id):
            self.regex_str = expr[len(self.regex_id):]
            self.regex = re.compile(self.regex_str, re.IGNORECASE)
        else:
            for expr_str in expr.split(','):
                account = expr_str.strip()
                if not account:
                    # Tolerate empty elements instead of failing on account[0]
                    continue
                if account.startswith('!'):
                    account = account[1:]
                    self._negated_accounts[account] = account
                else:
                    self._accepted_accounts[account] = account

    def match(self, obj, invalid=False):
        """
        Return True if the account text of `obj` passes the filter.

        :param invalid: whether the attempt failed because the user/domain is
                        unknown (used by the "unknown" expression).
        """
        account = str(obj)

        if self.regex is not None:
            return self.regex.match(account) is not None

        if self._unknown and invalid:
            return True

        if account in self._accepted_accounts:
            return True

        return bool(self._negated_accounts) and account not in self._negated_accounts


class SIPSecRule(object):
    """
    One row of the rules table.

    A rule has filters that select failed registration attempts, a
    failed-attempts threshold within a time frame, and the actions to run
    once the threshold is reached. Attempts are grouped into SIPSecRuleMatch
    objects keyed by the rule id/name and the source IP. When the
    corresponding filters are configured, the key also includes the profile
    and the account.
    """

    def __init__(self, dbreg, logger):
        """
        :param dbreg: a row of the rules table (columns as created by
                      SIPSecMonitor._db_init).
        :param logger: logger shared with the rule's filters and actions.
        """
        self._logger = logger
        self._id = int(dbreg.rule_id)
        self._name = dbreg.name
        self._failed_attempts = dbreg.failed_attempts
        # time_frame is stored in minutes; matching works in seconds
        self._time_frame = (dbreg.time_frame * 60)

        # A NULL/empty profile filter means "any profile"
        self._profile_filter = dbreg.profile_filter or ''

        self._actions = []
        self._ip_filter = None
        self._account_filter = None
        self._user_agent_filter = None
        # Active matches, keyed by the match key string
        self._matches = {}

        # build filters
        if dbreg.src_ip_filter_expr:
            self._ip_filter = IPFilter(dbreg.src_ip_filter_expr, self, self._logger)

        if dbreg.account_filter_expr:
            self._account_filter = AccountFilter(dbreg.account_filter_expr, self, self._logger)

        if dbreg.user_agent_filter_expr:
            self._user_agent_filter = UserAgentFilter(dbreg.user_agent_filter_expr, self, self._logger)

        # build actions; action_expr is a nullable column, so treat NULL as
        # "no actions" rather than failing on None.split()
        self._build_actions(dbreg.action_expr or '')

    def _build_actions(self, expr):
        """Parse a comma-separated "name=args" action list (only block_ip is known)."""
        self._logger.debug("Building actions from expr {}".format(expr))
        action_list = expr.split(',')
        for action_str in action_list:
            action_str = action_str.strip()
            if not action_str:
                self._logger.debug("Ignoring empty action for rule '{}'".format(self))
                continue
            action_name, sep, action_args = action_str.partition('=')
            if not sep:
                self._logger.warning("Invalid action '{}' for rule '{}'".format(action_str, self))
                continue
            if action_name == 'block_ip':
                action = BlockIPAction(action_args, self, self._logger)
            else:
                self._logger.warning("Invalid action '{}' for rule '{}'".format(action_str, self))
                continue
            self._actions.append(action)

    def match(self, event):
        """
        Feed one registration event to the rule.

        Returns False when the event is ignored: successful authentication, no
        source IP, or a filter did not match. Returns True when the event was
        recorded against a match; in that case, if the failure threshold was
        reached within the time frame, the actions ran and the match was
        discarded.
        """
        secmon = get_secmon()
        status = event.get_header('auth-status')

        # Only match failure attempts
        if status == 'ok':
            return False

        src_ip = event.get_header('network-ip')
        if not src_ip:
            # Nothing to key the match on, and nothing to block
            self._logger.debug("Ignoring registration event without network-ip")
            return False

        profile = event.get_header('profile-name')
        # Missing username/realm headers become empty strings instead of
        # raising TypeError on concatenation
        account = (event.get_header('username') or '') + "@" + (event.get_header('realm') or '')
        user_agent = event.get_header('user-agent')
        reason = event.get_header('fail-reason')
        invalid_user = reason in ('invalid-user', 'domain-not-bound')

        # Match profile
        if self._profile_filter and profile != self._profile_filter:
            return False

        # Match IP
        if self._ip_filter and not self._ip_filter.match(src_ip):
            return False

        # Match account
        if self._account_filter and not self._account_filter.match(account, invalid_user):
            return False

        # Match user agent (only when the event carries one)
        if user_agent is not None and self._user_agent_filter is not None \
                and not self._user_agent_filter.match(user_agent):
            return False

        # Build the key grouping related attempts into one match
        match_key = "{}/{}".format(self._id, self._name)

        if self._profile_filter:
            match_key += '/' + profile

        # Always include src ip in the key
        match_key += '/' + src_ip

        if self._account_filter:
            if self._account_filter.regex is not None:
                match_key += '/' + self._account_filter.regex_str
            elif invalid_user:
                match_key += '/' + '@invalid-user@'
            else:
                match_key += '/' + account

        if match_key in self._matches:
            rule_match = self._matches[match_key]
            # cancel previous safety timer; a fresh one is armed below
            secmon.sched.cancel(rule_match.sched_event)
        else:
            rule_match = SIPSecRuleMatch(match_key, self, self._logger)
            rule_match.profile = profile
            rule_match.src_ip = src_ip
            rule_match.account = account
            self._matches[match_key] = rule_match

        # Save event and its capture time
        curtime = time.time()
        event.secmon_time = curtime
        rule_match.events.append(event)

        # schedule an expiration timer for this match in case no events match anymore within the time frame
        rule_match.sched_event = secmon.sched.enter(self._time_frame, 1, self._delete_match, (rule_match, ))

        # clear up any events that are no longer valid considering the time frame
        expired_events = []
        for regev in rule_match.events:
            diff = (curtime - regev.secmon_time)
            if diff < 0:
                # Attempt to fix obviously broken time (time on system changed)
                regev.secmon_time = curtime
            if diff > self._time_frame:
                expired_events.append(regev)
        for expev in expired_events:
            rule_match.events.remove(expev)

        # if not enough attempts have occurred within the valid time frame, we're done here
        if len(rule_match.events) < self._failed_attempts:
            return True

        # Enough events collected now, trigger the actions
        self._logger.info("Executing actions for rule match {} from IP {}".format(match_key, src_ip))
        for action in self._actions:
            action.execute(rule_match)

        # get rid of the match
        secmon.sched.cancel(rule_match.sched_event)
        self._delete_match(rule_match)

        return True

    def _delete_match(self, match):
        """Forget `match` (run by its expiration timer or after its actions ran)."""
        key = str(match)
        self._logger.debug("Deleting rule match '{}'".format(key))
        del self._matches[key]

    def __str__(self):
        return self._name


class SIPSecMonitor(sngpy.DBService):
    """
    The SIP security monitor service.

    Reads <database> and <switch> settings from the service configuration,
    loads rules from the database and restores the firewall blocks recorded
    there. It then listens for REGISTER_ATTEMPT and SIP_LIMIT_EXCEEDED events
    from the switch and feeds them to the rules and the blocking logic.
    Receiving SIGHUP makes the event loop run unblock_expired() on its next
    housekeeping pass.
    """

    _service_name = 'sipsecmon'

    def __init__(self, logger):
        """
        :param logger: service logger; a console handler using the base
                       class' _logformat is attached to it.
        """

        # <database> settings and the pyodbc connection
        self._dbconf = {}
        self._dbconf_params = ['connection-string', 'table-prefix']
        self._dbconn = None

        # <switch> settings (event source connection)
        self._swconf = {}
        self._swconf_params = ['connection-string']

        # seconds to wait between retries
        self._sleep_interval = 1

        # table names, derived from table-prefix in _parse_db()
        self.blocked_objects_table = ''
        self.expired_objects_table = ''
        # seconds an expired block is remembered for repeat-offence escalation
        self.expiration_cache_period = 60
        self.rules_table = ''

        self.iptc_max_retry = 100
        self.iptc_retry_sleep = 0.01

        self._rules = []

        self._oswc_conn = None

        # firewall chain names per IP version; the SIPSecmonChain itself is
        # created in _setup_blocked_objects()
        self.sip_chain = None
        self.chain = {4: 'sip_security', 6: 'sip_security6'}

        self._logger = logger

        formatter = logging.Formatter(self._logformat)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        self._logger.addHandler(console_handler)

        self.sched = sngpy.Scheduler(time.time, time.sleep)

        # output sink for cmd_exec()
        self._null = open(os.devnull, 'w')

        # custom SIPSecMon options should be added here via parser.add_option()
        self.parser = OptionParser()
        self.parser.add_option("", "--unblock-ip", dest="unblock_ip",
                               help="Unblock the given IP address")
        self.parser.add_option("", "--init-database", action="store_true",
                               dest="init_database",
                               help="Initialize Database")

        super(SIPSecMonitor, self).__init__()

        # SIGHUP asks the event loop to unblock expired objects
        # (see _housekeeping() and _reload())
        self.reload = False

        def sigreload(signum, frame):
            # Only raise a flag here; the work happens in _housekeeping().
            # Use self rather than the module-level instance, which is not
            # assigned yet while this constructor is still running.
            self.reload = True
        signal.signal(signal.SIGHUP, sigreload)

    def _parse_db(self, params):
        """Validate <database> params and derive the table names from table-prefix."""
        self._logger.debug("Reading database parameters")
        for p in params:
            self._logger.debug("database: {}={}".format(p.attrib['name'], p.attrib['value']))
            if p.attrib['name'] in self._dbconf_params:
                self._dbconf[p.attrib['name']] = p.attrib['value']
            else:
                raise ValueError("Unknown database XML parameter {}".format(p.attrib['name']))
        if 'connection-string' not in self._dbconf:
            raise ValueError("Missing database connection-string XML parameter")
        if 'table-prefix' not in self._dbconf:
            self._dbconf['table-prefix'] = ''
        self.rules_table = self._dbconf['table-prefix'] + "rules"
        self.blocked_objects_table = self._dbconf['table-prefix'] + "blocked_objects"
        self.expired_objects_table = self._dbconf['table-prefix'] + "expired_objects"

    def _parse_switch(self, params):
        """Validate <switch> params; connection-string is mandatory."""
        self._logger.debug("Reading switch parameters")
        for p in params:
            self._logger.debug("switch: {}={}".format(p.attrib['name'], p.attrib['value']))
            if p.attrib['name'] in self._swconf_params:
                self._swconf[p.attrib['name']] = p.attrib['value']
            else:
                raise ValueError("Unknown switch XML parameter {}".format(p.attrib['name']))
        if 'connection-string' not in self._swconf:
            raise ValueError("Missing switch connection-string XML parameter")

    def configure(self):
        """Load the service configuration; returns False (after logging) on any error."""
        ret = True
        try:
            tree = super(SIPSecMonitor, self).configure()

            dbconf = tree.find('database')
            if dbconf is None:
                raise ValueError("Missing <database> configuration")
            else:
                params = self._get_params(dbconf)
                self._parse_db(params)

            swconf = tree.find('switch')
            if swconf is None:
                raise ValueError("Missing <switch> configuration")
            else:
                params = self._get_params(swconf)
                self._parse_switch(params)

        except:
            self._logger.critical("Failed to configure service")
            exc_type, exc_value, exc_traceback = sys.exc_info()
            self._print_exception(exc_type, exc_value, exc_traceback)
            ret = False

        return ret

    def _db_init(self):
        """Create the rules, blocked-objects and expired-objects tables if missing."""
        try:
            # create rules table
            dbquery = "CREATE TABLE IF NOT EXISTS `{}` ("  """
                      `rule_id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `name` VARCHAR(255) NULL ,
                      `failed_attempts` INT UNSIGNED NOT NULL ,
                      `time_frame` INT UNSIGNED NOT NULL ,
                      `profile_filter` TEXT NULL ,
                      `src_ip_filter_expr` TEXT NULL ,
                      `account_filter_expr` TEXT NULL ,
                      `user_agent_filter_expr` TEXT NULL ,
                      `action_expr` TEXT NULL ,
                      `comments` TEXT NULL ,
                      PRIMARY KEY (`rule_id`) ,
                      UNIQUE INDEX `rule_id_UNIQUE` (`rule_id` ASC) )
                   ENGINE = InnoDB;
                   """.format(self.rules_table)

            self.sql_exec(dbquery)

            # create blocked objects table
            dbquery = "CREATE TABLE IF NOT EXISTS `{}` (" """
                      `blocked_object_id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `rule_id` INT UNSIGNED NOT NULL ,
                      `name` VARCHAR(255) NOT NULL ,
                      `type` VARCHAR(10) NOT NULL ,
                      `block_time` DATETIME NOT NULL ,
                      `block_expiration` DATETIME NOT NULL ,
                      PRIMARY KEY (`blocked_object_id`) ,
                      UNIQUE INDEX `blocked_object_id_UNIQUE` (`blocked_object_id` ASC) ,
                      INDEX `rule_id` (`rule_id` ASC))
                   ENGINE = InnoDB;
                   """.format(self.blocked_objects_table)
            self.sql_exec(dbquery)

            # create expired objects table
            dbquery = "CREATE TABLE IF NOT EXISTS `{}` (" """
                      `name` VARCHAR(255) NOT NULL ,
                      `block_count` INT UNSIGNED NOT NULL DEFAULT 1 ,
                      `last_expiration` DATETIME NOT NULL ,
                      PRIMARY KEY (`name`))
                   ENGINE = InnoDB;
                   """.format(self.expired_objects_table)
            self.sql_exec(dbquery)
        except:
            self._logger.error("Failed to initialize database")
            raise

    def _load_rules(self):
        """Build a SIPSecRule for every row of the rules table."""
        try:
            dbquery = "SELECT * FROM {}".format(self.rules_table)
            cursor = self.sql_exec(dbquery)
            rulecnt = 0
            for rule_reg in cursor:
                rule = SIPSecRule(rule_reg, self._logger)
                self._rules.append(rule)
                rulecnt = rulecnt + 1
            self._logger.info("Loaded {} rules".format(rulecnt))
        except:
            self._logger.error("Failed to load rules")
            raise

    def cmd_exec(self, cmd, errcheck=True):
        """Run cmd (an argv list) with output discarded; return its exit code."""
        self._logger.debug("Executing command {}".format(' '.join(cmd)))
        rc = subprocess.call(cmd, stdout=self._null, stderr=self._null)
        if errcheck and rc:
            self._logger.error("Command {} returned {}".format(' '.join(cmd), rc))
        return rc

    def _setup_blocked_objects(self):
        """Set up the firewall chains, re-apply stored blocks and drop expired ones."""
        try:
            self._logger.info("Setting up blocked objects")
            self.sip_chain = SIPSecmonChain(self.chain, self._logger)

            # Make sure all blocked obj are still blocked and
            # unblock the expired ones
            block_all()
            unblock_expired()
        except (iptc.IPTCError, iptc.XTablesError):
            self._logger.error("Failed to setup blocked objects")
            raise

    def unblock_ip(self, ipaddr):
        """
        Force-unblock an IP from the command line (--unblock-ip).

        NOTE: the parameter name shadows the ipaddr module inside this method;
        harmless here since the module is not used in the body.
        """
        self._db_connect()
        return unblock_ip_ex(ipaddr)

    def init_database(self):
        """Connect to the database (which creates the tables); True on success."""
        # connection to the database takes care of initialization
        try:
            self._db_connect()
            return True
        except:
            return False

    def _housekeeping(self):
        """Run due scheduler events and service a pending SIGHUP request."""
        self.sched.fast_run()
        if self.reload:
            self._reload()

    def _connect(self):
        """Connect to the event source, retrying until connected or the daemon stops."""
        # Connection loop (wait until we can connect or the daemon is stopped)
        self._oswc_conn = None
        connection_string = self._swconf['connection-string']
        while self.daemon_alive:
            self._housekeeping()
            # connect to the event provider
            if self._oswc_conn is None:
                self._oswc_conn = oswc.create_connection(connection_string, logger=self._logger)
                if self._oswc_conn is None:
                    self._logger.error("Failed to create connection to {}".format(connection_string))
                    time.sleep(self._sleep_interval)
                    continue
                listener = RegistrationListener(self._rules, self._logger)
                listener.set_filter(['REGISTER_ATTEMPT'])
                self._oswc_conn.add_event_listener(listener)

                listener = SIPLimitsListener(self._logger)
                listener.set_filter(['SIP_LIMIT_EXCEEDED'])
                self._oswc_conn.add_event_listener(listener)

            if not self._oswc_conn.connect():
                self._logger.debug("Failed to connect to {}".format(connection_string))
                time.sleep(self._sleep_interval)
                continue

            self._logger.info("Connected to {}".format(connection_string))
            break

    def _reload(self):
        """Handle a SIGHUP request: unblock objects whose block has expired."""
        self.reload = False
        self._logger.info("Forcing unblock of expired objects")

        # Ideally we should also remove scheduled entries otherwise the
        # scheduled callback will trigger and will attempt to unblock an IP
        # which is already unblocked, or worse it will unblock an IP entry that
        # a different rule triggered.  This is not a major issue though since
        # typically if a given IP was unblocked manually it is because it is
        # trusted and the fact that was blocked again is just a rule
        # mis-configuration or something alike

        unblock_expired()

    def _db_connect(self):
        """Open the pyodbc connection and make sure the tables exist."""
        self._dbconn = None
        try:
            self._dbconn = pyodbc.connect(self._dbconf['connection-string'])
        except Exception as e:
            self._logger.error("Failed to connect to the database {}: {}".format(self._dbconf['connection-string'], e))
            raise

        # initialize the database
        self._db_init()

    def run(self):
        """Service main: set everything up, then process events until stopped."""

        super(SIPSecMonitor, self).run()

        # connect to the database
        self._db_connect()

        # Load the rules
        self._load_rules()

        # Verify blocked objects are still blocked and unblock any expired registers
        self._setup_blocked_objects()

        # Connect to the event source
        self._connect()

        self._logger.info("{} is now running".format(self._service_name))

        # Event loop
        while self.daemon_alive:
            try:
                self._housekeeping()
                sched_sleep = self.sched.next_event_time_delta()
                if not sched_sleep or sched_sleep < 0:
                    # Never hand receive_event() a zero (or, after the clock
                    # went back, negative) timeout: a zero timeout seems to
                    # (oddly) block. Compare by value, not identity, so a float
                    # 0.0 is caught too; a None delta is treated the same way.
                    # Wait one second instead to be on the safe side.
                    sched_sleep = 1

                e = self._oswc_conn.receive_event(timeout=sched_sleep * 1000)
                if e is not None and str(e) == 'OS_DISCONNECTED':
                    self._logger.info("Server connection lost")
                    self._connect()

            except KeyboardInterrupt:
                self._logger.info("Stopping {}. User aborted.".format(self._service_name))
                break
            except:
                exc_type, exc_value, exc_traceback = sys.exc_info()
                self._print_exception(exc_type, exc_value, exc_traceback)
                time.sleep(self._sleep_interval)

        self._logger.info("{} is now terminating".format(self._service_name))


class RegistrationListener(oswc.EventListener):
    """
    oswc event listener that passes every registration event to each rule.
    """

    def __init__(self, rules, logger):
        """
        :param rules: iterable of rule objects; each must expose match(event)
        :param logger: logger used to trace received events at debug level
        """
        self._rules = rules
        self._logger = logger
        # Initialize the oswc.EventListener base itself. The previous
        # super(oswc.EventListener, self) started the MRO lookup *after*
        # EventListener, so its __init__ was silently skipped.
        super(RegistrationListener, self).__init__()

    def on_event(self, event):
        """Log the event, then offer it to every rule in order."""
        self._logger.debug("Received registration event {}".format(event))
        for rule in self._rules:
            rule.match(event)


class SIPLimitsListener(oswc.EventListener):
    """
    oswc event listener for SIP rate-limit violations.

    It logs the offending request and blocks the source IP for one limit period.
    """

    def __init__(self, logger):
        """
        :param logger: logger used to report every limit violation
        """
        self._logger = logger
        # Initialize oswc.EventListener itself. super(oswc.EventListener, self)
        # would start the lookup past it and skip its __init__.
        super(SIPLimitsListener, self).__init__()

    def on_event(self, event):
        """
        Someone exceeded a limit, so we block their IP for the limit seconds
        to get them back in line with the required limits. Ideally we would
        block only for the remaining time of the limit period, but for now
        we block for a full period.

        :raises TypeError/ValueError: from int() if the 'limit-seconds' header
            is missing or not an integer
        """
        # Read the headers used twice (log message and block) only once.
        network_ip = event.get_header('network-ip')
        limit_seconds = event.get_header('limit-seconds')
        self._logger.warning(
            "Received sip limits event from profile '{}' for sip agent '{}' from '{}' (usage exceeded {} requests per {} seconds, method = {}, host = {})".format(
                event.get_header('profile-name'),
                event.get_header('user-agent'), network_ip,
                event.get_header('limit-max'), limit_seconds,
                event.get_header('sip-method'), event.get_header('limit-host')
            )
        )
        # perform_block() is defined elsewhere in this module. The leading 0
        # is presumably the blocked-object 'type' (see the maintenance notes
        # at the top of the file) - TODO confirm.
        perform_block(0, network_ip, int(limit_seconds))


# A single logger for the whole script, named after the monitor service.
logger = logging.getLogger(name=SIPSecMonitor._service_name)
# Let every record through at the logger level; any narrower filtering would
# have to come from handlers configured elsewhere.
logger.setLevel(logging.DEBUG)


def get_secmon():
    """
    Accessor for the module-wide SIPSecMonitor instance.

    Returns None until the script section below has successfully
    constructed the monitor.
    """
    global secmon
    return secmon

secmon = None
try:
    secmon = SIPSecMonitor(logger)

    # One-shot control commands aimed at an already running instance.
    if secmon.options.stop is not None:
        secmon.stop()
        sys.exit(0)
    if secmon.options.restart is not None:
        secmon.restart()
        sys.exit(0)

    # Without -c, fall back to the conf path of the running service
    # (a convenient way to retrieve the -c option when the service is up).
    if secmon.options.conf_path is None:
        secmon.options.conf_path = sngpy.Service.find_conf_path()
    if secmon.options.conf_path is None:
        secmon.parser.print_help()
        secmon.parser.error("-c is required to find the configuration path")
        # NOTE(review): if parser is an optparse.OptionParser, error() already
        # exits with status 2 and this line is unreachable - kept as fallback.
        sys.exit(1)

    if secmon.configure() is False:
        secmon.parser.error("Failed to configure daemon using file {}".format(secmon.options.conf_path))
        sys.exit(1)

    # Maintenance commands: the exit status reflects whether they succeeded.
    if secmon.options.unblock_ip:
        sys.exit(0 if secmon.unblock_ip(secmon.options.unblock_ip) else 1)
    if secmon.options.init_database:
        sys.exit(0 if secmon.init_database() else 1)

    # Starting (daemon) or running (foreground) is refused while a pid file exists.
    if os.path.exists(secmon.pidfile):
        logger.error("Service seems to be running already, pid file {} already exists".format(secmon.pidfile))
        sys.exit(1)

    # Background when start was requested, otherwise run in the foreground.
    (secmon.start if secmon.options.start is not None else secmon.run)()

    sys.exit(0)
except KeyboardInterrupt:
    logger.info("Received keyboard interrupt, exiting.")
    sys.exit(0)
except SystemExit as e:
    # Caught only to keep it out of the catch-all below. It must be re-raised
    # so that Python exits with the status passed to sys.exit().
    if not sys.stdin.isatty():
        logger.info("SystemExit status {}".format(e.code))
    raise
except:
    exc_type, exc_value, exc_traceback = sys.exc_info()
    sngpy.print_exception(exc_type, exc_value, exc_traceback)
    if not sys.stdin.isatty():
        logger.error("Unexpected exception {}/{}, aborting.".format(exc_type, exc_value))
    sys.exit(1)