#!/usr/bin/env python
# vim: tabstop=4 softtabstop=4 shiftwidth=4 textwidth=80 smarttab expandtab
"""
* Copyright (C) 2014  Sangoma Technologies Corp.
* All Rights Reserved.
*
* Author:
* Leonardo Lang <lang@sangoma.com>
*
* This code is Sangoma Technologies Confidential Property.
* Use of and access to this code is covered by a previously executed
* non-disclosure agreement between Sangoma Technologies and the Recipient.
* This code is being supplied for evaluation purposes only and is not to be
* used for any other purpose.
"""

# Monotonic clock section: ctypes binding to clock_gettime() plus the
# monotonic_time() helper built on it.
# Only monotonic_time is exported by "from <module> import *".
__all__ = ["monotonic_time"]

import ctypes
import os

# Clock id handed to clock_gettime() by monotonic_time().
CLOCK_MONOTONIC_RAW = 4 # see <linux/time.h>

class timespec(ctypes.Structure):
    """ctypes structure filled in by clock_gettime().

    Holds two C ``long`` members, in this order: ``tv_sec`` (seconds) and
    ``tv_nsec`` (nanoseconds).
    """

    _fields_ = [
        ('tv_sec', ctypes.c_long),
        ('tv_nsec', ctypes.c_long),
    ]

# librt is loaded with use_errno=True so that ctypes.get_errno() can be used
# to read the error code after a failing call.
librt = ctypes.CDLL('librt.so.1', use_errno=True)
# Foreign function used by monotonic_time(): takes a clock id (int) and a
# pointer to a timespec to fill in. No restype is set, so the ctypes default
# return type applies.
clock_gettime = librt.clock_gettime
clock_gettime.argtypes = [ctypes.c_int, ctypes.POINTER(timespec)]

def monotonic_time():
    """Return the current CLOCK_MONOTONIC_RAW reading in seconds, as a float.

    Raises:
        OSError: clock_gettime() returned a non-zero status; the exception
            carries the ctypes errno and the matching os.strerror() text.
    """
    now = timespec()
    status = clock_gettime(CLOCK_MONOTONIC_RAW, ctypes.pointer(now))

    if status == 0:
        # whole seconds plus the nanosecond part scaled down to seconds
        return now.tv_sec + now.tv_nsec * 1e-9

    code = ctypes.get_errno()
    raise OSError(code, os.strerror(code))

class StopRequestException(Exception):
    """Raised by exit_handler() to make the main() loop stop and clean up."""

import re
import sys
import time
import errno
import signal
import syslog
import logging
import subprocess
import optparse

# Applied by get_ifaces() to each line of "ip link" output. Matches lines of
# the form "<number>: <name>: <FLAGS> <rest>" and captures:
#   group 1 = interface name, group 2 = text between "<" and ">" (split on
#   commas by get_ifaces() to look for the UP flag), group 3 = rest of line.
regifaces = re.compile('^[0-9]+: ([^:]+): <(.+)> (.+)')
# Separator used to split the --interfaces value: one comma or one space.
regsplits = re.compile('[, ]')

class DirectSyslogHandler(logging.Handler):
    """Logging handler that writes formatted records through syslog.syslog()."""

    def __init__(self):
        """Initialise the handler and the level-name to syslog-priority map."""
        logging.Handler.__init__(self)
        # Level names that are not listed here are sent as LOG_INFO by emit().
        self.levelmap = {
            'DEBUG': syslog.LOG_DEBUG,
            'INFO': syslog.LOG_INFO,
            'NOTICE': syslog.LOG_NOTICE,
            'WARNING': syslog.LOG_WARNING,
            'ERROR': syslog.LOG_ERR,
            'CRITICAL': syslog.LOG_CRIT,
        }

    def emit(self, record):
        """Format *record* and send it to syslog with the mapped priority.

        Failures never propagate to the caller of the logging API; they are
        handed to handleError(), the standard logging hook for handler errors,
        instead of being dropped silently.
        """
        try:
            msg = self.format(record)
            level = self.levelmap.get(record.levelname, syslog.LOG_INFO)
            syslog.syslog(level, msg)
        except Exception:
            # NOTE(review): this also intercepts any Exception subclass raised
            # by a signal handler while a record is being emitted.
            self.handleError(record)

def create_logger(debug=None):
    """Build the logger used by this program, named after the running script.

    The logger always gets a DirectSyslogHandler and starts at INFO level.
    When *debug* is anything other than None, a second handler writing to
    sys.stderr is attached and the level is lowered to DEBUG.

    Args:
        debug: None for normal operation; any other value turns debug on.

    Returns:
        The configured logging.Logger instance.
    """
    progname = os.path.basename(os.path.abspath(sys.argv[0]))
    layout = progname + '[%(process)d]: %(levelname)s: ' + '%(message)s'

    log = logging.getLogger(progname)
    log.setLevel(logging.INFO)

    to_syslog = DirectSyslogHandler()
    to_syslog.setFormatter(logging.Formatter(layout))
    log.addHandler(to_syslog)

    if debug is None:
        return log

    # debug mode: mirror everything on stderr with the same layout
    to_stderr = logging.StreamHandler(sys.stderr)
    to_stderr.setFormatter(logging.Formatter(layout))
    log.addHandler(to_stderr)
    log.setLevel(logging.DEBUG)

    return log

# custom SnortConfUpdater options should be added here via parser.add_option()
parser = optparse.OptionParser()

parser.add_option("--debug", dest="debug", action='store_true',
                  help="Enable debug mode")
parser.add_option("--interfaces", dest="interfaces", metavar='IFACES',
                  help="Set allowed interfaces list to IFACES")

parser.add_option("--start", dest="daemon_start", action='store_true',
                  help="Start daemon")
parser.add_option("--stop", dest="daemon_stop", action='store_true',
                  help="Stop daemon")

# Remaining positional arguments (args) are later passed to main().
(options, args) = parser.parse_args()

logger = create_logger(options.debug)

# The store_true options have no default, so they stay None unless given on
# the command line; hence the "is not None" tests.
if options.daemon_start is not None and options.daemon_stop is not None:
    # Newline-terminated so the shell prompt does not end up on the same line.
    sys.stderr.write('Cannot execute with both --start and --stop.\n')
    sys.exit(1)

def get_ifaces(ifacelist):
    """Report the monitored network interfaces and whether each one is up.

    Runs "ip link" and inspects every output line matched by regifaces.

    Args:
        ifacelist: None, or a string of interface names separated by commas
            or spaces. When given, only the listed names are reported; when
            None, only names starting with "eth" are reported.

    Returns:
        A dict mapping interface name to True when "UP" is one of the
        comma-separated flags of that interface, False otherwise.
    """
    child = subprocess.Popen([ 'ip', 'link' ], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output = child.communicate()[0]

    allowed = None
    if ifacelist is not None:
        allowed = regsplits.split(ifacelist)

    found = {}

    for line in output.splitlines():
        hit = regifaces.match(line)
        if hit is None:
            continue

        ifname = hit.group(1)
        flags = hit.group(2).split(',')

        if allowed is None:
            selected = ifname.startswith('eth')
        else:
            selected = ifname in allowed

        if selected:
            found[ifname] = 'UP' in flags

    return found

def run_snort(iface, cmdargs):
    """Fork a child process and exec snort in it for one interface.

    The child points stdin at /dev/null and stdout/stderr at
    /var/log/snort/snortmessages_<iface>.log (opened in append mode), then
    replaces itself with /usr/local/nsc/bin/snort. The child never returns
    into the caller: if the exec fails, or anything else goes wrong on the
    child side, it terminates immediately with exit status 123.

    Args:
        iface: interface name, passed to snort after "-i".
        cmdargs: list of extra arguments appended to the snort command line.

    Returns:
        The pid of the forked child (only the parent returns).

    Raises:
        OSError: if os.fork() itself fails.
    """
    logger.info('Starting snort on "' + iface + '"...')

    snort_prog = '/usr/local/nsc/bin/snort'

    args = ['snort', '-i', iface, '--create-pidfile']
    args.extend(cmdargs)

    # Logged before forking so that the line is written once, by the parent.
    logger.debug('calling "%s" with: %s' % (snort_prog, str(args)))

    pid = os.fork()
    if pid != 0:
        return pid

    # ---- child process only from here on ----
    try:
        try:
            fout = open('/var/log/snort/snortmessages_' + iface + '.log', 'a')
            fnul = open('/dev/null', 'r')

            os.dup2(fnul.fileno(), 0)
            os.dup2(fout.fileno(), 1)
            os.dup2(fout.fileno(), 2)
        except EnvironmentError as e:
            # Redirection is best-effort: snort is started even without it.
            logger.warning('Failed to redirect snort output on "%s": %s' % (iface, str(e)))

        # execvp() only returns by raising; on success this process is snort.
        os.execvp(snort_prog, args)
    except Exception as e:
        logger.error('Failed to execute "%s": %s' % (snort_prog, str(e)))
    finally:
        # os._exit() skips interpreter cleanup and, above all, keeps the child
        # from falling back into the watchdog loop it was forked from.
        os._exit(123)

def exit_handler(signum, frame):
    """Signal handler: turn the received signal into a StopRequestException.

    Both arguments are supplied by the signal module and are not used.
    """
    raise StopRequestException()

def main(cmdargs, ifacelist):
    """Run the watchdog loop that keeps snort alive on every usable interface.

    Every POLL_TIME seconds the interfaces reported by get_ifaces() are
    scanned. Snort is started through run_snort() on each interface that is up
    and has no tracked child. After the scan at most one terminated child is
    reaped with a non-blocking waitpid(); forgetting its pid makes a later
    pass start it again.

    Restart throttling: start times are recorded per interface, oldest first.
    Once more than MAX_RESTART_COUNT starts are recorded, the oldest one is
    dropped and, if it is less than MIN_INTERVAL_TIME seconds old, the
    interface is disabled for PROC_DISABLE_TIME minutes. The disable time is
    counted down in POLL_TIME steps on passes where the interface is up.

    The loop ends when exit_handler() raises StopRequestException (installed
    here for SIGTERM and SIGINT). Every tracked snort is then sent SIGTERM and
    waited for.

    Args:
        cmdargs: extra command-line arguments handed to run_snort().
        ifacelist: interface filter handed to get_ifaces(); may be None.
    """

    POLL_TIME = 3            # seconds slept between two passes

    MAX_RESTART_COUNT = 5    # recorded starts tolerated before throttling
    MIN_INTERVAL_TIME = 30   # seconds; window checked against the oldest start
    PROC_DISABLE_TIME = 60   # minutes a throttled interface stays disabled

    signal.signal(signal.SIGTERM, exit_handler)
    signal.signal(signal.SIGINT, exit_handler)

    proc_fail = set()   # interfaces already reported as not up
    proc_pids = {}      # interface name -> pid of its running snort
    exec_time = {}      # interface name -> list of start times, oldest first
    exec_wait = {}      # interface name -> remaining disable time (seconds)

    try:
        while True:
            for name, status in get_ifaces(ifacelist).items():
                if proc_pids.get(name) is not None:
                    continue

                if not status:
                    # report a down interface only once per outage
                    if name not in proc_fail:
                        logger.info('Delaying snort start on "%s"...' % name)
                        proc_fail.add(name)
                    continue

                proc_fail.discard(name)

                sec_now = monotonic_time()

                if exec_wait.get(name) is not None:
                    remaining = max(0, exec_wait.get(name) - POLL_TIME)
                    if remaining > 0:
                        exec_wait[name] = remaining
                        continue
                    del exec_wait[name]

                history = exec_time.setdefault(name, [])

                if len(history) > MAX_RESTART_COUNT:
                    # pop(0): the oldest recorded start delimits the window;
                    # a plain pop() would take the most recent start instead.
                    sec_fst = history.pop(0)

                    if (sec_now - sec_fst) < MIN_INTERVAL_TIME:
                        logger.warning('Snort restarted more than %d times in the last %d seconds, disabling for %d minutes.' %
                            (MAX_RESTART_COUNT, MIN_INTERVAL_TIME, PROC_DISABLE_TIME))
                        exec_wait[name] = PROC_DISABLE_TIME * 60
                        continue

                history.append(sec_now)

                pid = run_snort(name, cmdargs)
                if pid is not None:
                    proc_pids[name] = pid

            try:
                (pid, status) = os.waitpid(-1, os.WNOHANG)

                if pid > 0:
                    name = None
                    for iname, ipid in proc_pids.items():
                        if pid == ipid:
                            name = iname
                            break

                    if name is not None:
                        if os.WIFEXITED(status):
                            logger.warning('Snort on "%s" exited with code %d, restarting..' % (name, os.WEXITSTATUS(status)))
                        elif os.WIFSIGNALED(status):
                            logger.warning('Snort on "%s" killed with signal %d, restarting..' % (name, os.WTERMSIG(status)))

                        del proc_pids[name] # this will trigger a restart

                    else:
                        logger.warning('Unknown child (pid=%d) terminated with status %d.' % (pid, status))

            except OSError as e:
                # EAGAIN and ECHILD are deliberately not reported
                if e.errno != errno.EAGAIN and e.errno != errno.ECHILD:
                    logger.warning('waitpid(): %s' % e.strerror)

            time.sleep(POLL_TIME)

    except StopRequestException:
        # First signal every child, then wait for each of them in turn.
        for name, pid in proc_pids.items():
            logger.info('Killing snort %d at interface %s...' % (pid, name))
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError:
                # best-effort: a failed kill() must not stop the cleanup
                pass

        for name, pid in proc_pids.items():
            try:
                os.waitpid(pid, 0)
            except OSError as e:
                if e.errno != errno.EAGAIN:
                    logger.warning('Failed to waitpid(%d): %s' % (pid, e.strerror))

# File holding the pid of the daemonized watchdog (written by --start, read
# by --stop).
pidfname = '/var/run/snort-watchdog.pid'

if options.daemon_start is not None:
    # --start: fork, let the parent exit, and run the watchdog in the child.
    try:
        forked = os.fork()
    except OSError as e:
        logger.error('Failed to fork(): %s' % e.strerror)
        sys.exit(1)

    if forked != 0:
        sys.exit(0)

    try:
        os.setsid()
    except OSError as e:
        logger.warning('Failed to setsid(): %s' % e.strerror)

    try:
        fp = open(pidfname, 'w')
        try:
            fp.write("%s\n" % os.getpid())
        finally:
            fp.close()
    except EnvironmentError as e:
        logger.error('Failed to open(%s): %s' % (pidfname, e.strerror))
        sys.exit(1)

    main(args, options.interfaces)

    try:
        os.unlink(pidfname)
    except OSError:
        # best-effort removal of the pid file
        pass

    sys.exit(0)

if options.daemon_stop is not None:
    # --stop: send SIGTERM to the recorded pid and wait up to 15 seconds for
    # its /proc entry to disappear. Exit status is 0 only when it is gone.
    retnum = 1

    try:
        fp = open(pidfname, 'r')
        try:
            data = fp.read(1024).rstrip('\n')
        finally:
            fp.close()

        pidnum = int(data)

    except (EnvironmentError, ValueError) as e:
        logger.debug('File %s not found or invalid: %s.' % (pidfname, str(e)))
        sys.exit(retnum)

    procdir = '/proc/%d' % pidnum

    if os.path.isdir(procdir):
        try:
            os.kill(pidnum, signal.SIGTERM)
        except OSError as e:
            # ESRCH: the process went away between the check and the kill.
            if e.errno != errno.ESRCH:
                logger.error('Failed to kill(%d): %s' % (pidnum, e.strerror))
                sys.exit(retnum)

    for _ in range(15):
        if not os.path.isdir(procdir):
            break
        time.sleep(1)

    if os.path.isdir(procdir):
        logger.info('Unable to stop snort watchdog daemon (pid=%d).' % pidnum)
    else:
        retnum = 0

    sys.exit(retnum)

# without arguments
main(args, options.interfaces)
