#!/usr/local/nsc/bin/python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2012  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author(s)
* David Yat Sin <dyatsin@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

import sys
import os
import time
import logging
import re
import oswc
import sngpy
import pyodbc
import subprocess
import signal
import iptc
import netifaces as ni
import errno
import ctypes
from optparse import OptionParser

class DBAction(object):
    """Persist media-session records in the mediamon sessions table.

    All operations are static and go through the MediaMonitor returned by
    get_mediamon().  Every session value is passed to sql_exec() as a bound
    query parameter (``?`` placeholders), so data arriving in media-session
    events can never alter the SQL statement itself.  Only the table name,
    which comes from the MediaMonitor configuration, is interpolated.
    """

    def __init__(self):
        # The original logged through self._logger without ever assigning it,
        # which raised AttributeError on construction; use the monitor's logger.
        self._logger = get_mediamon()._logger
        self._logger.debug("Creation DBAction")

    @staticmethod
    def insert(session_id, source_ip, destination_ip, forward_ip, rtp_port):
        """Insert a new session row.

        :param session_id: session identifier (from the event header).
        :param source_ip: source IP of the session.
        :param destination_ip: external IP the media is addressed to.
        :param forward_ip: internal media interface IP traffic is forwarded to.
        :param rtp_port: RTP port of the session.
        """
        mediamon = get_mediamon()
        query = """ INSERT into %s (session_id, source_ip, destination_ip, forward_ip, rtp_port) VALUES (?, ?, ?, ?, ?)
                """ % (mediamon.sessions_table)

        mediamon.sql_exec(query, session_id, source_ip, destination_ip, forward_ip, rtp_port)

    @staticmethod
    def delete(session_id):
        """Delete the row(s) for *session_id*."""
        mediamon = get_mediamon()
        # Bound parameter instead of '%s' interpolation: session_id comes from
        # an external event and must not be able to inject SQL.
        query = "DELETE from %s WHERE session_id = ? " % (mediamon.sessions_table)
        mediamon.sql_exec(query, session_id)

    @staticmethod
    def update(session_id, source_ip, destination_ip, forward_ip, rtp_port):
        """Overwrite the addressing columns of the row for *session_id*."""
        mediamon = get_mediamon()
        # Same parameter style as insert(); values are bound, not interpolated.
        query = ("UPDATE %s SET source_ip = ?, destination_ip = ?, forward_ip = ?, "
                 "rtp_port = ? WHERE session_id = ? " % (mediamon.sessions_table))
        mediamon.sql_exec(query, source_ip, destination_ip, forward_ip, rtp_port, session_id)


class IPTableAction(object):
    """Manage the mediamon iptables chains and the per-session NAT rules.

    Three top-level chains are managed: a PREROUTING and a POSTROUTING chain
    in the "nat" table, and a FORWARD chain in the "filter" table.  Per-session
    DNAT/SNAT rules are placed in port-range subchains (see
    create_pre_subchain / create_post_subchain), presumably to keep per-packet
    rule traversal short.

    Most operations run inside a retry loop.  When libiptc reports EAGAIN
    (logged as a failed kernel consistency check), the operation is retried
    after a short delay.  Any other error is logged and the operation is
    abandoned, or re-raised in the case of the per-session helpers.
    """

    # Delay (seconds) between retries after an EAGAIN from libiptc.
    _RETRY_DELAY = 0.05

    def __init__(self, mediamon):
        """Bind table/chain objects using the names configured on *mediamon*.

        :param mediamon: the MediaMonitor providing chain names, options and
            the logger.
        """
        # Second argument False disables autocommit; writes below are
        # bracketed by restart()/close().
        self._table_nat = iptc.Table("nat", False)
        self._table_filter = iptc.Table("filter", False)

        self._mediamon = mediamon
        self._logger = mediamon._logger
        self._prerouting_chain = iptc.Chain(self._table_nat, mediamon.prerouting_chain_name)
        self._postrouting_chain = iptc.Chain(self._table_nat, mediamon.postrouting_chain_name)
        self._forward_chain = iptc.Chain(self._table_filter, mediamon.forward_chain_name)

        self.num_subchains = 0
        self.subchain_size = 0

        # hashtables of subchains indexed by their first port
        self._pre_subchains = dict()
        self._post_subchains = dict()

    def insert_chain(self, table, chain):
        """Create *chain* in *table* unless a chain of that name exists.

        Does not retry; any iptc error is logged and re-raised so the caller's
        retry loop can handle EAGAIN.
        """
        for ch in table.chains:
            if ch.name == chain.name:
                self._logger.debug("Chain %s already exists" % chain.name)
                return

        try:
            table.create_chain(chain)
        except (iptc.IPTCError, iptc.XTablesError) as e:
            if ctypes.get_errno() != errno.EAGAIN:
                self._logger.error("Failed to insert chain %s in table %s: %s" % (chain.name, table.name, str(e)))
            else:
                self._logger.debug("Kernel consistency check failed at insert_chain..")
            raise

    def setup_chains(self):
        """Create the mediamon chains and hook them into the built-in chains.

        Each of PREROUTING, POSTROUTING (nat) and FORWARD (filter) gets a jump
        rule to the corresponding mediamon chain.
        """
        while True:
            try:
                self._table_nat.restart()

                self.insert_chain(self._table_nat, self._prerouting_chain)
                input_chain = iptc.Chain(self._table_nat, "PREROUTING")

                rule = iptc.Rule()
                rule.target = iptc.Target(rule, self._mediamon.prerouting_chain_name)
                input_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up prerouting chain: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at setup_chains PREROUTING, retrying..")
                    time.sleep(self._RETRY_DELAY)

        while True:
            try:
                self._table_nat.restart()

                self.insert_chain(self._table_nat, self._postrouting_chain)
                input_chain = iptc.Chain(self._table_nat, "POSTROUTING")

                rule = iptc.Rule()
                rule.target = iptc.Target(rule, self._mediamon.postrouting_chain_name)
                input_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up postrouting chain error:%s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at setup_chains POSTROUTING, retrying..")
                    time.sleep(self._RETRY_DELAY)

        while True:
            try:
                self._table_filter.restart()

                if self._table_filter.is_chain(self._forward_chain) is False:
                    self._table_filter.create_chain(self._forward_chain)

                input_chain = iptc.Chain(self._table_filter, "FORWARD")
                rule = iptc.Rule()
                rule.target = iptc.Target(rule, self._mediamon.forward_chain_name)
                input_chain.insert_rule(rule)

                self._table_filter.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up forward chain: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at setup_chains FORWARD, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def flush_chains(self, log_error=True):
        """Remove every rule from the three mediamon chains.

        :param log_error: when False, a non-EAGAIN failure is logged at info
            level instead of error (e.g. when the chains may not exist yet).
        """
        while True:
            try:
                self._table_nat.restart()
                self._table_filter.restart()

                self._prerouting_chain.flush()
                self._postrouting_chain.flush()
                self._forward_chain.flush()

                self._table_nat.close()
                self._table_filter.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    if log_error:
                        self._logger.error("Failed flushing chains: %s" % (str(e)))
                    else:
                        self._logger.info("Failed flushing chains: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at flush_chains, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def remove_references(self, chain, target):
        """Delete every rule in *chain* whose target name equals *target*.

        Best effort and not retried: iptc errors are logged and swallowed.
        """
        try:
            for r in chain.rules:
                if str(r.target.name) == target:
                    chain.delete_rule(r)

        except (iptc.IPTCError, iptc.XTablesError) as e:
            if ctypes.get_errno() != errno.EAGAIN:
                self._logger.error("Failed to remove references error:%s" % (str(e)))
            else:
                self._logger.debug("Kernel consistency check failed at remove_references..")

    def remove_chains(self):
        """Unhook and delete the three mediamon chains."""
        while True:
            try:
                self._table_nat.restart()
                self._table_filter.restart()

                self.remove_references(iptc.Chain(self._table_nat, 'PREROUTING'), self._mediamon.prerouting_chain_name)
                self.remove_references(iptc.Chain(self._table_nat, 'POSTROUTING'), self._mediamon.postrouting_chain_name)
                self.remove_references(iptc.Chain(self._table_filter, 'FORWARD'), self._mediamon.forward_chain_name)

                self.delete_chain(self._table_nat, self._postrouting_chain)
                self.delete_chain(self._table_nat, self._prerouting_chain)
                self.delete_chain(self._table_filter, self._forward_chain)

                self._table_nat.close()
                self._table_filter.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    # Previously swallowed silently; keep it non-fatal (the
                    # chains may legitimately be absent) but leave a trace.
                    self._logger.debug("Failed removing chains: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at remove_chains, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _find_subchain(self, subchains, port):
        """Return the subchain in *subchains* whose port range covers *port*.

        A subchain keyed by ``port_start`` covers the inclusive range
        ``port_start .. port_start + self.subchain_size``.  Returns None when
        no registered subchain covers the port.
        """
        for port_start, subchain in subchains.items():
            if port_start <= port <= port_start + self.subchain_size:
                return subchain
        return None

    def _get_pre_subchain(self, port):
        """Return the PREROUTING (DNAT) subchain covering *port*, or None."""
        return self._find_subchain(self._pre_subchains, port)

    def _get_post_subchain(self, port):
        """Return the POSTROUTING (SNAT) subchain covering *port*, or None."""
        return self._find_subchain(self._post_subchains, port)

    def create_pre_subchain(self, subchain_port_range_start, subchain_port_range_stop):
        """Create a nat subchain for a UDP destination-port range.

        The subchain is registered under its start port and a jump rule,
        matching the range by destination port, is added to the mediamon
        prerouting chain.
        """
        subchain_name = "mediamon-pre-%d-%d" % (subchain_port_range_start, subchain_port_range_stop)
        self._logger.debug("Creating subchain:%s" % subchain_name)

        # Built before the loop so 'subchain' is always bound below, even if
        # restart() fails before the chain is created.
        subchain = iptc.Chain(iptc.TABLE_NAT, subchain_name)

        while True:
            try:
                iptc.TABLE_NAT.restart()
                iptc.TABLE_NAT.create_chain(subchain)
                iptc.TABLE_NAT.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up %s chain: %s" % (subchain_name, str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at create_pre_subchain/chain, retrying..")
                    time.sleep(self._RETRY_DELAY)

        self._pre_subchains[subchain_port_range_start] = subchain

        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.protocol = "udp"
                match = iptc.Match(rule, "udp")
                match.dport = "%d:%d" % (subchain_port_range_start, subchain_port_range_stop)
                rule.add_match(match)
                rule.target = iptc.Target(rule, subchain_name)
                self._prerouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up rule %s:%s" % (subchain_name, str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at create_pre_subchain/rule, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def create_post_subchain(self, subchain_port_range_start, subchain_port_range_stop):
        """Create a nat subchain for a UDP source-port range.

        The subchain is registered under its start port and a jump rule,
        matching the range by source port, is added to the mediamon
        postrouting chain.
        """
        subchain_name = "mediamon-post-%d-%d" % (subchain_port_range_start, subchain_port_range_stop)
        self._logger.debug("Creating subchain:%s" % subchain_name)

        # Built before the loop so 'subchain' is always bound below.
        subchain = iptc.Chain(iptc.TABLE_NAT, subchain_name)

        while True:
            try:
                iptc.TABLE_NAT.restart()
                iptc.TABLE_NAT.create_chain(subchain)
                iptc.TABLE_NAT.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up %s chain: %s" % (subchain_name, str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at create_post_subchain/chain, retrying..")
                    time.sleep(self._RETRY_DELAY)

        self._post_subchains[subchain_port_range_start] = subchain

        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.protocol = "udp"
                match = iptc.Match(rule, "udp")
                match.sport = "%d:%d" % (subchain_port_range_start, subchain_port_range_stop)
                rule.add_match(match)
                rule.target = iptc.Target(rule, subchain_name)
                self._postrouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed setting-up rule %s:%s" % (subchain_name, str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at create_post_subchain/rule, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def delete_chain(self, table, chain):
        """Delete every rule of *chain*, then the chain itself from *table*.

        Not retried here.  A failure deleting an individual rule is logged
        and skipped; a failure deleting the chain propagates to the caller.
        """
        for r in chain.rules:
            try:
                chain.delete_rule(r)
            except (iptc.IPTCError, iptc.XTablesError) as e:
                self._logger.error("Failed to delete rules:%s" % (str(e)))

        table.delete_chain(chain)

    def remove_subchains(self):
        """Delete all registered pre/post subchains and forget them."""
        while True:
            try:
                iptc.TABLE_NAT.restart()

                for subchain in self._pre_subchains.values():
                    self.delete_chain(iptc.TABLE_NAT, subchain)
                for subchain in self._post_subchains.values():
                    self.delete_chain(iptc.TABLE_NAT, subchain)

                iptc.TABLE_NAT.close()

                # The chains are gone; drop the stale objects so that
                # _get_*_subchain() can no longer hand them out.
                self._pre_subchains.clear()
                self._post_subchains.clear()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    # Previously swallowed silently; keep it non-fatal.
                    self._logger.debug("Failed removing subchains: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at remove_subchains, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_forward(self, port_range_start, port_range_stop):
        """ACCEPT forwarded UDP with source or destination port in the range."""
        while True:
            try:
                self._table_filter.restart()

                rule = iptc.Rule()
                rule.protocol = "udp"
                match = iptc.Match(rule, "udp")
                match.dport = "%d:%d" % (port_range_start, port_range_stop)
                rule.add_match(match)
                rule.target = iptc.Target(rule, "ACCEPT")
                self._forward_chain.insert_rule(rule)

                rule = iptc.Rule()
                rule.protocol = "udp"
                match = iptc.Match(rule, "udp")
                match.sport = "%d:%d" % (port_range_start, port_range_stop)
                rule.add_match(match)
                rule.target = iptc.Target(rule, "ACCEPT")
                self._forward_chain.insert_rule(rule)

                self._table_filter.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable forward accept rule %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_forward, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_outbound_drop(self, ip_address):
        """Insert a DROP rule in the prerouting chain for UDP from *ip_address*."""
        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.src = ip_address
                rule.protocol = "udp"
                rule.target = iptc.Target(rule, "DROP")
                self._prerouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable outbound DROP rule %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_outbound_drop, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_inbound_drop(self, netif, port_range_start, port_range_stop):
        """Insert a DROP rule for UDP arriving on *netif* to the port range."""
        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.in_interface = netif
                rule.protocol = "udp"
                match = iptc.Match(rule, "udp")
                match.dport = "%d:%d" % (port_range_start, port_range_stop)
                rule.add_match(match)
                rule.target = iptc.Target(rule, "DROP")
                self._prerouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable inbound DROP rule %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_inbound_drop, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_outbound_snat(self, netif, ip_address):
        """SNAT UDP from *ip_address* leaving via *netif*.

        The new source is options.global_src_ip when set, otherwise the first
        IPv4 address of *netif*.  Does nothing if *netif* has no IPv4 address.
        """
        # Checked before touching the table: the original restarted the table
        # and then returned without closing it.
        addresses = ni.ifaddresses(netif)
        if ni.AF_INET not in addresses:
            return
        netif_addr = addresses[ni.AF_INET][0]['addr']

        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.out_interface = netif
                rule.protocol = "udp"
                rule.src = ip_address

                target = iptc.Target(rule, "SNAT")
                if self._mediamon.options.global_src_ip:
                    target.to_source = self._mediamon.options.global_src_ip
                else:
                    target.to_source = netif_addr
                rule.target = target

                self._postrouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable outbound SNAT rule: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_outbound_snat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_inbound_dnat(self, netif, ip_address, port_range_start, port_range_stop):
        """DNAT UDP arriving on *netif* for the port range to *ip_address*.

        Does nothing if *netif* has no IPv4 address.
        """
        if ni.AF_INET not in ni.ifaddresses(netif):
            return

        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.in_interface = netif
                rule.protocol = "udp"

                match = iptc.Match(rule, "udp")
                match.dport = "%d:%d" % (port_range_start, port_range_stop)
                rule.add_match(match)

                target = iptc.Target(rule, "DNAT")
                target.to_destination = ip_address
                rule.target = target

                self._prerouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable inbound DNAT rule: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_inbound_dnat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable_loop_dnat(self, netif, external_netif, ip_address, port_range_start, port_range_stop):
        """DNAT UDP arriving on *netif* to *external_netif*'s address.

        Matching packets in the port range go to *ip_address*.  Does nothing
        if either interface lacks an IPv4 address.
        """
        if ni.AF_INET not in ni.ifaddresses(netif):
            return

        external_addresses = ni.ifaddresses(external_netif)
        if ni.AF_INET not in external_addresses:
            return
        external_addr = external_addresses[ni.AF_INET][0]['addr']

        while True:
            try:
                self._table_nat.restart()

                rule = iptc.Rule()
                rule.in_interface = netif
                rule.protocol = "udp"
                rule.dst = external_addr

                match = iptc.Match(rule, "udp")
                match.dport = "%d:%d" % (port_range_start, port_range_stop)
                rule.add_match(match)

                target = iptc.Target(rule, "DNAT")
                target.to_destination = ip_address
                rule.target = target

                self._prerouting_chain.insert_rule(rule)

                self._table_nat.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable loop DNAT rule: %s" % (str(e)))
                    break
                else:
                    self._logger.debug("Kernel consistency check failed at enable_loop_dnat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _create_accept(self, forward_ip, port):
        """Build an ACCEPT rule for UDP from *forward_ip*.

        The rule matches source ports *port* and *port* + 1 (RTP and RTCP).
        """
        rule = iptc.Rule()
        rule.protocol = "udp"
        rule.src = forward_ip
        match = iptc.Match(rule, "udp")
        match.sport = "%d:%d" % (port, port + 1)
        rule.add_match(match)

        rule.target = iptc.Target(rule, "ACCEPT")
        return rule

    def _disable_accept(self, forward_ip, port):
        """Remove the ACCEPT rule built by _create_accept(); re-raise on error."""
        while True:
            try:
                self._prerouting_chain.table.restart()

                rule = self._create_accept(forward_ip, port)
                self._prerouting_chain.delete_rule(rule)

                self._prerouting_chain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    # Message previously said "enable" (swapped with _enable_accept).
                    self._logger.error("Failed to disable ACCEPT ip:%s port:%d err:%s" % (forward_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at disable_accept, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _enable_accept(self, forward_ip, port):
        """Insert the ACCEPT rule built by _create_accept(); re-raise on error."""
        while True:
            try:
                self._prerouting_chain.table.restart()

                rule = self._create_accept(forward_ip, port)
                self._prerouting_chain.insert_rule(rule)

                self._prerouting_chain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    # Message previously said "disable" (swapped with _disable_accept).
                    self._logger.error("Failed to enable ACCEPT ip:%s port:%d err:%s" % (forward_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at enable_accept, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _create_snat(self, destination_ip, port):
        """Build a rule that SNATs UDP with source *port* to *destination_ip*:*port*."""
        rule = iptc.Rule()
        rule.protocol = "udp"
        match = iptc.Match(rule, "udp")
        match.sport = "%d:%d" % (port, port)
        rule.add_match(match)

        target = iptc.Target(rule, "SNAT")
        target.to_source = "%s:%d-%d" % (destination_ip, port, port)
        rule.target = target
        return rule

    def _enable_snat(self, destination_ip, port):
        """Insert SNAT rules for *port* and *port* + 1 into the covering subchain.

        Logs and returns if no subchain covers *port*; re-raises other errors.
        """
        # Checked before use: the original dereferenced subchain.table first,
        # so a missing subchain crashed with AttributeError.
        subchain = self._get_post_subchain(port)
        if subchain is None:
            self._logger.error("Failed to get subchain for port %d" % port)
            return

        while True:
            try:
                subchain.table.restart()

                rule_rtp = self._create_snat(destination_ip, port)
                rule_rtcp = self._create_snat(destination_ip, port + 1)

                subchain.insert_rule(rule_rtp)
                subchain.insert_rule(rule_rtcp)

                subchain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable SNAT ip:%s port:%d err:%s" % (destination_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at enable_snat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _disable_snat(self, destination_ip, port):
        """Delete the SNAT rules added by _enable_snat()."""
        subchain = self._get_post_subchain(port)
        if subchain is None:
            self._logger.error("Failed to get subchain for port %d" % port)
            return

        while True:
            try:
                subchain.table.restart()

                rule_rtp = self._create_snat(destination_ip, port)
                rule_rtcp = self._create_snat(destination_ip, port + 1)

                subchain.delete_rule(rule_rtp)
                subchain.delete_rule(rule_rtcp)

                subchain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to disable SNAT ip:%s port:%d err:%s" % (destination_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at disable_snat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _create_dnat(self, destination_ip, forward_ip, port):
        """Build a DNAT rule for UDP to *destination_ip*:*port*.

        Matching packets are redirected to *forward_ip*:*port*.
        """
        rule = iptc.Rule()
        rule.protocol = "udp"
        # The rule deliberately does not filter on source IP.  When the host
        # is behind NAT, the source_ip we are told may be a NATed address, and
        # no event arrives once the real source UDP IP is learnt.  Filtering
        # by source IP also adds little protection, because anyone who can
        # see the SIP signalling knows the source IP and can forge it.
        rule.dst = destination_ip
        match = iptc.Match(rule, "udp")
        match.dport = "%d:%d" % (port, port)
        rule.add_match(match)

        target = iptc.Target(rule, "DNAT")
        target.to_destination = "%s:%d-%d" % (forward_ip, port, port)
        rule.target = target
        return rule

    def _enable_dnat(self, destination_ip, forward_ip, port):
        """Insert DNAT rules for *port* and *port* + 1 into the covering subchain.

        Logs and returns if no subchain covers *port*; re-raises other errors.
        """
        subchain = self._get_pre_subchain(port)
        if subchain is None:
            self._logger.error("Failed to get subchain for port %d" % port)
            return

        while True:
            try:
                subchain.table.restart()

                rule_rtp = self._create_dnat(destination_ip, forward_ip, port)
                rule_rtcp = self._create_dnat(destination_ip, forward_ip, port + 1)

                subchain.insert_rule(rule_rtp)
                subchain.insert_rule(rule_rtcp)

                subchain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to enable DNAT destination_ip:%s forward_ip:%s port:%d err:%s" % (destination_ip, forward_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at enable_dnat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def _disable_dnat(self, destination_ip, forward_ip, port):
        """Delete the DNAT rules added by _enable_dnat()."""
        subchain = self._get_pre_subchain(port)
        if subchain is None:
            self._logger.error("Failed to get subchain for port %d" % port)
            return

        while True:
            try:
                subchain.table.restart()

                rule_rtp = self._create_dnat(destination_ip, forward_ip, port)
                rule_rtcp = self._create_dnat(destination_ip, forward_ip, port + 1)

                subchain.delete_rule(rule_rtp)
                subchain.delete_rule(rule_rtcp)

                subchain.table.close()
                break

            except (iptc.IPTCError, iptc.XTablesError) as e:
                if ctypes.get_errno() != errno.EAGAIN:
                    self._logger.error("Failed to disable DNAT destination_ip:%s forward_ip:%s port:%d err:%s" % (destination_ip, forward_ip, port, str(e)))
                    raise
                else:
                    self._logger.debug("Kernel consistency check failed at disable_dnat, retrying..")
                    time.sleep(self._RETRY_DELAY)

    def enable(self, destination_ip, forward_ip, rtp_port):
        """Install all rules for one media session; returns True.

        The DNAT rules deliberately ignore the session's source IP.  A UDP
        source IP is trivially spoofed, so filtering on it adds no real
        security.  It also breaks sessions behind NAT, where the source IP we
        are given may not be the one the packets actually carry.
        """
        # SNAT pair: rewrite the source of traffic from the internal media
        # interface (forward_ip) to the public IP (destination_ip), matched on
        # UDP source port.  Traffic from media interfaces is dropped by
        # default (to keep connection tracking from kicking in before these
        # rules exist), so it must also be explicitly accepted.
        self._enable_snat(destination_ip, rtp_port)
        self._enable_accept(forward_ip, rtp_port)

        # DNAT pair: route packets from external interfaces to the proper
        # media interface IP.  Rules are matched on UDP destination port and
        # on destination IP, so each applies to one particular address.
        self._enable_dnat(destination_ip, forward_ip, rtp_port)

        return True

    def disable(self, destination_ip, forward_ip, rtp_port):
        """Tear down all the rules created by enable(); returns True."""
        self._disable_snat(destination_ip, rtp_port)
        self._disable_accept(forward_ip, rtp_port)

        self._disable_dnat(destination_ip, forward_ip, rtp_port)

        return True

class MediaSessionListener(oswc.EventListener):
    """Apply media-session events to the sessions table and to iptables.

    Events carry the session id, status and addressing in headers.  A status
    of 'enable' records the session and installs its NAT rules; any other
    status deletes the record and tears the rules down.
    """

    def __init__(self, logger):
        self._logger = logger
        # NOTE(review): super() is passed the base class rather than
        # MediaSessionListener, so oswc.EventListener.__init__ itself is
        # skipped and the next class in the MRO is initialised instead.
        # Left unchanged because the base constructor's signature is not
        # visible here - confirm whether this is intentional.
        super(oswc.EventListener, self).__init__()

    def on_event(self, event):
        """Handle one media-session event.

        :param event: oswc event whose headers are read: session-id, status,
            source-ip, destination-ip, forward-ip and rtp-port.
        """
        mediamon = get_mediamon()
        ip_action = get_ip_action()
        self._logger.debug("Received media session event %s\n" % event)
        session_id = event.get_header("session-id")
        status = event.get_header("status")
        source_ip = event.get_header("source-ip")
        destination_ip = event.get_header("destination-ip")
        forward_ip = event.get_header("forward-ip")
        # The header is externally supplied text: parse it as an integer
        # instead of eval()-ing it, which would execute arbitrary code.
        rtp_port = int(event.get_header("rtp-port"))

        if status == 'enable':
            mediamon._logger.debug("Enabling session %s,%s,%s,%s,%d" % (session_id, source_ip, destination_ip, forward_ip, rtp_port))
            DBAction.insert(session_id, source_ip, destination_ip, forward_ip, rtp_port)
            ip_action.enable(destination_ip, forward_ip, rtp_port)
        else:
            mediamon._logger.debug("Disabling session %s,%s,%s,%s,%d" % (session_id, source_ip, destination_ip, forward_ip, rtp_port))
            DBAction.delete(session_id)
            ip_action.disable(destination_ip, forward_ip, rtp_port)

class MediaMonitor(sngpy.DBService):
    """
    mediamon service: forwards RTP media to media modules with iptables.

    Configuration is read from the <mediamon>, <database> and <switch>
    sections. At start-up static rules are installed for every row of the
    media_modules table. When 'dynamic-rules' is enabled the service also
    connects to the switch, listens for MEDIA_SESSION events and adds or
    removes per-session rules, tracking sessions in the media_sessions
    table. Rules and chains are removed again when the service stops.
    """

    _service_name = 'mediamon'

    def __init__(self, logger):
        """
        Set default state, attach a console handler to logger, register the
        command line options and the SIGHUP handler, then initialize the
        sngpy.DBService base class.

        logger -- logging.Logger used for all service messages.
        """
        self._logger = logger

        # <database> section: accepted parameter names and parsed values
        self._dbconf = dict()
        self._dbconf_params = ['connection-string', 'table-prefix']
        self._dbconn = None

        # <switch> section
        self._swconf = dict()
        self._swconf_params = ['connection-string']

        # <mediamon> section
        self._conf = dict()
        self._conf_params = ['num-subchains', 'dynamic-rules']

        # seconds to sleep before retrying a failed connection / loop iteration
        self._sleep_interval = 1

        # table names; filled in by _parse_db() once table-prefix is known
        self.modules_table = ''
        self.sessions_table = ''

        self.sched = None

        # switch (oswc) connection, created by _connect()
        self._oswc_conn = None

        self.sysctl_bin = '/sbin/sysctl'
        self.conntrack_bin = '/usr/sbin/conntrack'
        self.sngtc_bin = '/usr/local/nsc/bin/sngtc_tool'

        self.prerouting_chain_name = 'mediamon_prerouting'
        self.postrouting_chain_name = 'mediamon_postrouting'
        self.forward_chain_name = 'mediamon_forward'
        self._dynamic_rules = False

        self._external_interfaces = ni.interfaces()

        self._null = None

        # values written to the nf_conntrack UDP timeout sysctls
        self._conntrack_udp_timeout = 10
        self._conntrack_udp_timeout_stream = 40

        # _logformat is not set in this class; presumably inherited from sngpy
        formatter = logging.Formatter(self._logformat)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        self._logger.addHandler(console_handler)

        self.sched = sngpy.Scheduler(time.time, time.sleep)

        # sink for stdout/stderr of the commands run by cmd_exec()
        self._null = open(os.devnull, 'w')

        # custom MediaMon options should be added here via parser.add_option()
        self.parser = OptionParser()
        self.parser.add_option("", "--init-database", action="store_true",
                            dest="init_database",
                            help="Initialize Database")

        self.parser.add_option("", "--show-status", action="store_true",
                            dest="show_status",
                            help="Show current sessions")

        self.parser.add_option("", "--global-src-ip", action="store",
                            dest="global_src_ip",
                            help="Set a global source IP ignoring interface addresses")

        super(MediaMonitor, self).__init__()

        # SIGHUP only sets self.reload; nothing in this class reads the flag
        # (NOTE(review): confirm it is consumed elsewhere).
        self.reload = False

        def sigreload(signum, frame):
            # Use the instance directly: the module-level mediamon global is
            # not assigned until this constructor has returned.
            self.reload = True
        signal.signal(signal.SIGHUP, sigreload)

    def _read_section(self, section, params, allowed, target):
        """
        Copy the name/value parameters of one configuration section into target.

        section -- label used in log and error messages
        params  -- iterable of parameter nodes exposing attrib['name'] and
                   attrib['value'] (presumably XML elements from _get_params())
        allowed -- accepted parameter names
        target  -- dict that receives name -> value

        Raises ValueError for a parameter name not in allowed.
        """
        self._logger.debug("Reading %s parameters" % section)
        for p in params:
            name = p.attrib['name']
            value = p.attrib['value']
            self._logger.debug("%s: %s=%s" % (section, name, value))
            if name not in allowed:
                raise ValueError("Unknown %s XML parameter %s" % (section, name))
            target[name] = value

    def _parse_db(self, params):
        """
        Parse the <database> section and derive the table names from the
        optional table-prefix. Raises ValueError on unknown or missing keys.
        """
        self._read_section('database', params, self._dbconf_params, self._dbconf)
        if 'connection-string' not in self._dbconf:
            raise ValueError("Missing database connection-string XML parameter")
        self._dbconf.setdefault('table-prefix', '')
        prefix = self._dbconf['table-prefix']
        self.modules_table = prefix + "media_modules"
        self.sessions_table = prefix + "media_sessions"

    def _parse_switch(self, params):
        """
        Parse the <switch> section. Raises ValueError on unknown keys or a
        missing connection-string.
        """
        self._read_section('switch', params, self._swconf_params, self._swconf)
        if 'connection-string' not in self._swconf:
            raise ValueError("Missing switch connection-string XML parameter")

    def _parse_conf(self, params):
        """
        Parse the <mediamon> section. num-subchains is required;
        dynamic-rules, when present, sets self._dynamic_rules.
        Raises ValueError on unknown or missing keys.
        """
        self._read_section('mediamon', params, self._conf_params, self._conf)
        if 'num-subchains' not in self._conf:
            raise ValueError("Missing mediamon num-subchains XML parameter")
        if 'dynamic-rules' in self._conf:
            self._dynamic_rules = self._is_true(self._conf['dynamic-rules'])

    def configure(self):
        """
        Load the configuration through the base class and parse the
        <mediamon>, <database> and <switch> sections, in that order.

        Returns True on success; on any error logs it and returns False.
        """
        ret = True
        try:
            tree = super(MediaMonitor, self).configure()

            sections = (('mediamon', self._parse_conf),
                        ('database', self._parse_db),
                        ('switch', self._parse_switch))
            for name, parse in sections:
                node = tree.find(name)
                if node is None:
                    raise ValueError("Missing <%s> configuration" % name)
                parse(self._get_params(node))

            if self._dynamic_rules and self.options.global_src_ip:
                self._logger.warning('--global-src-ip will be ignored when dynamic rules are enabled')

        except:
            # Catch-all on purpose: every configuration problem is reported
            # and turned into a False return for the caller.
            self._logger.critical("Failed to configure service")
            exc_type, exc_value, exc_traceback = sys.exc_info()
            self._print_exception(exc_type, exc_value, exc_traceback)
            ret = False

        return ret

    def cmd_exec(self, cmd, errcheck=True):
        """
        Run cmd (an argument list) with stdout/stderr discarded.

        Returns the exit code. A non-zero code is logged as an error when
        errcheck is true, otherwise only at debug level.
        """
        cmd_str = ' '.join(cmd)
        self._logger.debug("Executing command %s" % (cmd_str))
        rc = subprocess.call(cmd, stdout=self._null, stderr=self._null)
        if errcheck and rc:
            self._logger.error("Command %s returned %d" % (cmd_str, rc))
        else:
            self._logger.debug("Command %s returned %d" % (cmd_str, rc))
        return rc

    def init_database(self):
        """
        Connect to the database; the connection step creates the tables if
        needed. Returns True on success, False on failure (the failure is
        logged by _db_connect() / _db_init()).
        """
        try:
            self._db_connect()
            return True
        except:
            return False

    def show_status(self):
        """Print the sessions stored in the sessions table. Returns True."""
        self._db_connect()
        query = "SELECT * from %s" % (self.sessions_table)

        cursor = self.sql_exec(query)
        rule = "========================================================================================================================="
        num_sessions = 0
        print(rule)
        print("| Session ID                         \t| Source \t\t| Destination\t\t| Forward    \t\t| RTP   |")
        print(rule)
        for obj in cursor:
            print("| %s\t| %15s\t| %15s\t| %15s\t| %05s\t|" % (obj.session_id, obj.source_ip, obj.destination_ip, obj.forward_ip, obj.rtp_port))
            num_sessions += 1

        print(rule)
        print("| Number of sessions:%05d                                                                                              |" % num_sessions)
        print(rule)
        return True

    def _housekeeping(self):
        """Run the scheduler tasks that are due."""
        self.sched.fast_run()

    def _connect(self):
        """
        Connect to the switch event source, retrying every _sleep_interval
        seconds until it succeeds or the daemon stops. A MediaSessionListener
        filtered on MEDIA_SESSION is attached when the connection object is
        created.
        """
        self._oswc_conn = None
        connection_string = self._swconf['connection-string']
        while self.daemon_alive:
            self._housekeeping()
            # connect to the event provider
            if self._oswc_conn is None:
                self._oswc_conn = oswc.create_connection(connection_string, logger=self._logger)
                if self._oswc_conn is None:
                    self._logger.error("Failed to create connection to %s" % (connection_string))
                    time.sleep(self._sleep_interval)
                    continue

                listener = MediaSessionListener(self._logger)
                listener.set_filter(['MEDIA_SESSION'])
                self._oswc_conn.add_event_listener(listener)

            if not self._oswc_conn.connect():
                self._logger.debug("Failed to connect to %s" % (connection_string))
                time.sleep(self._sleep_interval)
                continue

            self._logger.info("Connected to %s" % (connection_string))
            break

    def _db_connect(self):
        """
        Open self._dbconn from the configured connection-string, then create
        the tables if needed. Connection errors are logged and re-raised.
        """
        self._dbconn = None
        try:
            self._dbconn = pyodbc.connect(self._dbconf['connection-string'])
        except Exception as e:
            self._logger.error("Failed to connect to the database %s: %s" % (self._dbconf['connection-string'], str(e)))
            raise

        # initialize the database
        self._db_init()

    def _db_init(self):
        """
        Create the media modules and media sessions tables if they do not
        exist. Errors are logged and re-raised.
        """
        try:
            # create media_modules tables
            dbquery = "CREATE TABLE IF NOT EXISTS `%s` ("  """
                      `id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `ip_address` TEXT NOT NULL ,
                      `access_interface` TEXT NOT NULL ,
                      `port_range_start` INT UNSIGNED NOT NULL ,
                      `port_range_stop` INT UNSIGNED NOT NULL ,
                      PRIMARY KEY (`id`) ,
                      UNIQUE INDEX `id_UNIQUE` (`id` ASC) )
                   ENGINE = InnoDB;
                   """ % (self.modules_table)

            self.sql_exec(dbquery)

            # create media_sessions tables
            dbquery = "CREATE TABLE IF NOT EXISTS `%s` ("  """
                      `id` INT UNSIGNED NOT NULL AUTO_INCREMENT ,
                      `session_id` VARCHAR(255) NULL ,
                      `source_ip` TEXT NOT NULL ,
                      `destination_ip` TEXT NOT NULL ,
                      `forward_ip` TEXT NOT NULL ,
                      `rtp_port` INT UNSIGNED NOT NULL ,
                      PRIMARY KEY (`id`) ,
                      UNIQUE INDEX `id_UNIQUE` (`id` ASC) )
                   ENGINE = InnoDB;
                   """ % (self.sessions_table)

            self.sql_exec(dbquery)
        except:
            self._logger.error("Failed to initialize database")
            raise

    def _setup_media_sessions(self):
        """
        Re-validate the sessions stored in the sessions table.

        For each stored session, ask the switch for its current media
        variables with uuid_getvars; cmd_callback() (presumably invoked by
        oswc when the command completes) then re-enables the session's
        rules or drops the stale row.
        """
        self._logger.info("Enabling existing media sessions")

        query = "SELECT * from %s" % (self.sessions_table)

        for obj in self.sql_exec(query):
            self._logger.info("Verifying existing session with uuid: %s" % obj.session_id)
            cmd_args = "%s local_media_ip remote_media_ip sng_forward_ip sng_forward_port" % (str(obj.session_id))
            callback_data = "session_info:%s" % obj.session_id
            self._oswc_conn.cmd_exec("uuid_getvars", cmd_args, self.cmd_callback, callback_data)

    def _flush_media_sessions(self):
        """
        Remove the iptables rules of every session in the sessions table.
        The rows themselves are kept.
        """
        ip_action = get_ip_action()

        self._logger.info("Disabling existing media sessions")

        query = "SELECT * from %s" % (self.sessions_table)

        # Disable the per-session rules of every stored session
        for obj in self.sql_exec(query):
            ip_action.disable(obj.destination_ip, obj.forward_ip, obj.rtp_port)

    def _remove_static_rules(self):
        """Remove the subchains created by _setup_static_rules() (dynamic rules only)."""
        ip_action = get_ip_action()
        if self._dynamic_rules:
            ip_action.remove_subchains()

    def _set_ext_ports(self, verb, flag):
        """
        Run sngtc_tool with flag on every distinct access_interface of the
        modules table. verb ("Enabling"/"Disabling") only appears in the log.
        """
        query = "SELECT DISTINCT access_interface FROM %s" % (self.modules_table)
        for obj in self.sql_exec(query):
            self._logger.debug("%s external RJ-45's on %s" % (verb, obj.access_interface))
            cmd = [self.sngtc_bin, '-dev', str(obj.access_interface), flag]
            # We currently do not know if we have a D500, therefore we expect
            # this command to fail with D100, D150 etc; a non-zero exit is
            # therefore only logged at debug level.
            self.cmd_exec(cmd, False)

    def _disable_ext_ports(self):
        """Disable the external RJ-45 ports on every access interface."""
        self._set_ext_ports("Disabling", '-disable-ext-ports')

    def _enable_ext_ports(self):
        """Re-enable the external RJ-45 ports on every access interface."""
        self._set_ext_ports("Enabling", '-enable-ext-ports')

    def _setup_static_rules(self):
        """
        Install the per-module rules for every row of the modules table.

        Interfaces named as access_interface in the table are access
        interfaces; every other interface except "lo" is external.
        With dynamic rules: outbound drop, forward rule, inbound drops on each
        external interface, and num-subchains pre/post subchains splitting
        the module's port range. Without dynamic rules: forward rule,
        SNAT/DNAT on each external interface and loop DNAT from each access
        interface to each external one.
        """
        ip_action = get_ip_action()

        query = "SELECT * FROM %s" % (self.modules_table)
        # The rows are walked twice below; materialize them once instead of
        # re-running the query after the first pass has consumed the result.
        modules = list(self.sql_exec(query))

        # Access interfaces, de-duplicated in first-seen order
        access_interfaces = []
        for obj in modules:
            if obj.access_interface not in access_interfaces:
                access_interfaces.append(obj.access_interface)

        # Everything that is neither an access interface nor loopback
        external_interfaces = [netif for netif in ni.interfaces()
                               if netif not in access_interfaces and netif != "lo"]

        for obj in modules:
            if self._dynamic_rules:
                self._logger.debug("Adding static rules for %s:%d-%d" % (obj.ip_address, obj.port_range_start, obj.port_range_stop))
                ip_action.enable_outbound_drop(obj.ip_address)

                ip_action.enable_forward(obj.port_range_start, obj.port_range_stop)

                for netif in external_interfaces:
                    ip_action.enable_inbound_drop(netif, obj.port_range_start, obj.port_range_stop)

                # Split the port range into num-subchains equal slices; ports
                # left over by the integer division get no subchain.
                ip_action.num_subchains = int(self._conf['num-subchains'])
                ip_action.subchain_size = ((obj.port_range_stop - obj.port_range_start) + 1) // ip_action.num_subchains

                # Traversing in reverse makes no functional difference; it only
                # makes the iptables listing easier to follow when debugging.
                for i in reversed(range(ip_action.num_subchains)):
                    subchain_port_range_start = (i * ip_action.subchain_size) + obj.port_range_start
                    subchain_port_range_stop = subchain_port_range_start + ip_action.subchain_size - 1
                    ip_action.create_pre_subchain(subchain_port_range_start, subchain_port_range_stop)
                    ip_action.create_post_subchain(subchain_port_range_start, subchain_port_range_stop)
            else:
                self._logger.debug("Adding static fixed rules for %s:%d-%d" % (obj.ip_address, obj.port_range_start, obj.port_range_stop))
                ip_action.enable_forward(obj.port_range_start, obj.port_range_stop)
                for netif in external_interfaces:
                    ip_action.enable_outbound_snat(netif, obj.ip_address)
                    ip_action.enable_inbound_dnat(netif, obj.ip_address, obj.port_range_start, obj.port_range_stop)
                for netif in access_interfaces:
                    for external_netif in external_interfaces:
                        ip_action.enable_loop_dnat(netif, external_netif, obj.ip_address, obj.port_range_start, obj.port_range_stop)

    def _flush_conntrack(self):
        """Delete UDP conntrack entries; exit code 1 (nothing to delete) is not an error."""
        self._logger.debug("Flushing connection tracker")
        cmd = [self.conntrack_bin, '-D', '-p', 'udp']
        # When no conntrack data is available it is normal we get err 1
        rc = self.cmd_exec(cmd, False)
        if rc > 1:
            self._logger.error("Failed to flush connection tracker")

    def _setup_conntrack(self):
        """Write the configured UDP timeouts to the nf_conntrack sysctls."""
        settings = (
            ("net.netfilter.nf_conntrack_udp_timeout", self._conntrack_udp_timeout),
            ("net.netfilter.nf_conntrack_udp_timeout_stream", self._conntrack_udp_timeout_stream),
        )
        for key, value in settings:
            cmd = [self.sysctl_bin, '-w', "%s=%d" % (key, value)]
            self.cmd_exec(cmd)

    def cmd_callback(self, result = None, obj = None):
        """
        Completion callback for commands sent with self._oswc_conn.cmd_exec().

        obj    -- callback data "<cmd type>:<cmd info>", e.g.
                  "session_info:<session_id>" as built by _setup_media_sessions()
        result -- command output. For session_info: one "name:value" pair per
                  line, or text starting with "-ERR", which is treated as the
                  session no longer existing.

        For session_info, the session's rules are re-enabled and its row
        updated; a missing session or incomplete variables delete the row.
        """
        if obj is None:
            return

        obj_args = re.split(r':+', obj)
        cmd_type = obj_args[0]

        # This could be abstracted into a dictionary of callback handlers later on
        if cmd_type != 'session_info':
            self._logger.error("Don't know how to handle callback for command:%s" % cmd_type)
            return

        session_id = obj_args[1]

        if result is None:
            self._logger.error("No result received for session: %s" % session_id)
            return

        if result.startswith("-ERR"):
            self._logger.debug("Session with id:%s does not exist anymore" % session_id)
            DBAction.delete(session_id)
            return

        source_ip = None
        destination_ip = None
        forward_ip = None
        rtp_port = 0

        for variable_line in result.split('\n'):
            # Split on the first colon only so IPv6 values stay intact
            name, _, value = variable_line.partition(':')
            value = value.strip()
            if name == "local_media_ip":
                destination_ip = value
            elif name == "remote_media_ip":
                source_ip = value
            elif name == "sng_forward_ip":
                forward_ip = value
            elif name == "sng_forward_port":
                # Switch-supplied text: parse as an integer, never eval() it.
                # An unparsable port is treated as missing (0).
                try:
                    rtp_port = int(value)
                except ValueError:
                    rtp_port = 0

        if not (source_ip and destination_ip and forward_ip and rtp_port):
            self._logger.error("Failed to obtain all required parameters for session: %s source_ip:%s destination_ip:%s forward_ip:%s rtp_port:%d" % (session_id, source_ip, destination_ip, forward_ip, rtp_port))

            DBAction.delete(session_id)
            return

        ip_action = get_ip_action()
        ip_action.enable(destination_ip, forward_ip, rtp_port)
        DBAction.update(session_id, source_ip, destination_ip, forward_ip, rtp_port)

    def run(self):
        """
        Daemon main body.

        Connects to the database (and to the switch when dynamic rules are
        enabled), applies the conntrack sysctls, disables the external
        RJ-45 ports, sets up and flushes the chains, installs the static
        rules, flushes conntrack and re-validates stored sessions. It then
        processes switch events (or just sleeps without dynamic rules) until
        the daemon stops, and finally removes the rules and chains and
        re-enables the external ports.
        """
        super(MediaMonitor, self).run()
        ip_action = get_ip_action()

        # connect to the database
        self._db_connect()

        # Connect to the event source only if dynamic rules are enabled
        if self._dynamic_rules:
            self._connect()

        self._setup_conntrack()

        self._disable_ext_ports()

        ip_action.setup_chains()
        ip_action.flush_chains()

        self._setup_static_rules()

        # presumably so existing UDP flows are re-evaluated against the new rules
        self._flush_conntrack()

        if self._dynamic_rules:
            self._setup_media_sessions()

        self._logger.info("%s is now running" % (self._service_name))

        # Event loop
        while self.daemon_alive:
            try:
                if self._dynamic_rules:
                    self._housekeeping()
                    sched_sleep = self.sched.next_event_time_delta()
                    # receive_event() with a 0 timeout seems to (oddly) block,
                    # and the delta can go non-positive when the clock goes
                    # back, so wait one second instead in those cases.
                    if sched_sleep <= 0:
                        sched_sleep = 1

                    e = self._oswc_conn.receive_event(timeout=sched_sleep*1000)
                    if e is not None and str(e) == 'SERVER_DISCONNECTED':
                        self._logger.info("Server connection lost")
                        self._flush_media_sessions()
                        self._connect()
                        self._setup_media_sessions()
                else:
                    time.sleep(1)

            except KeyboardInterrupt:
                self._logger.info("Stopping %s. User aborted." %
                    (self._service_name))
                break
            except:
                # Catch-all on purpose: log and keep looping so the daemon
                # survives transient errors and the teardown below still runs.
                exc_type, exc_value, exc_traceback = sys.exc_info()
                self._print_exception(exc_type, exc_value, exc_traceback)
                time.sleep(self._sleep_interval)

        self._logger.info("%s is now terminating" % (self._service_name))

        if self._dynamic_rules:
            self._flush_media_sessions()

        ip_action.flush_chains(False)
        self._remove_static_rules()
        ip_action.remove_chains()

        self._enable_ext_ports()


# Service-wide logger named after MediaMonitor._service_name; its level is
# DEBUG so this logger itself filters out none of the standard levels.
logger = logging.getLogger(name=MediaMonitor._service_name)
logger.setLevel(level=logging.DEBUG)

## Entry point: the code below runs when this module is executed ##

def get_mediamon():
    """Return the module-global MediaMonitor instance (None until it is created)."""
    global mediamon
    return mediamon

def get_ip_action():
    """Return the module-global IPTableAction instance (None until it is created)."""
    global ip_action
    return ip_action

# Process-wide singletons, returned by get_mediamon() / get_ip_action()
mediamon = None
ip_action = None

try:
    mediamon = MediaMonitor(logger)
    ip_action = IPTableAction(mediamon)

    if mediamon.options.stop is not None:
        mediamon.stop()
        sys.exit(0)

    if mediamon.options.restart is not None:
        mediamon.restart()
        sys.exit(0)

    if mediamon.options.conf_path is None:
        # convenient way to retrieve -c option when service is running
        mediamon.options.conf_path = sngpy.Service.find_conf_path()

    if mediamon.options.conf_path is None:
        mediamon.parser.print_help()
        mediamon.parser.error("-c is required to find the configuration path")
        sys.exit(1)

    if mediamon.configure() is False:
        mediamon.parser.error("Failed to configure daemon using file %s" % mediamon.options.conf_path)
        sys.exit(1)

    if mediamon.options.init_database:
        if mediamon.init_database():
            sys.exit(0)
        else:
            sys.exit(1)

    if mediamon.options.show_status:
        if mediamon.show_status():
            sys.exit(0)
        else:
            sys.exit(1)

    # Following options (either start or run) should not be executed if the pid file exists
    if os.path.exists(mediamon.pidfile):
        logger.error("Service seems to be running already, pid file %s already exists" % mediamon.pidfile)
        sys.exit(1)

    # Decide whether to run in the background (Daemon mode) or foreground
    if mediamon.options.start is not None:
        mediamon.start()
    else:
        mediamon.run()

    sys.exit(0)
except KeyboardInterrupt:
    logger.info("Received keyboard interrupt, exiting.")
    sys.exit(0)
except SystemExit as e:
    # Caught only to keep it out of the catch-all below; it is re-raised so
    # Python still exits with the code passed to sys.exit(). The code may be
    # None or a string, hence %s rather than %d.
    if not sys.stdin.isatty():
        logger.info("SystemExit status %s" % e.code)
    raise
except:
    exc_type, exc_value, exc_traceback = sys.exc_info()
    sngpy.print_exception(exc_type, exc_value, exc_traceback)
    if not sys.stdin.isatty():
        logger.error("Unexpected exception %s/%s, aborting." % (exc_type, exc_value))
    sys.exit(1)

