#!/usr/bin/env python
# -*- mode: Python; -*-
#
# Authors: John Dennis <jdennis@redhat.com>
#
# Copyright (C) 2007,2008 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

import audit
import fnmatch
import gettext
import getopt
import glob
import gobject
import os
import time
import re
import socket as Socket
import sys
import subprocess

# Program name; passed below as the gettext translation domain.
prog_name = 'test_setroubleshootd'

from setroubleshoot.config import parse_config_setting, get_config
# Install the gettext translation function for this program's domain.
gettext.install(domain=prog_name, unicode=False, codeset='utf-8')
from setroubleshoot.log import log_init
# log_init() is called before the wildcard import of setroubleshoot.log below.
# The options passed request console output at 'error' level.
# NOTE(review): presumably the ordering ensures the exported loggers are
# configured before they are imported -- confirm against setroubleshoot.log.
log_init(sys.argv[0], {'console': True,
                       'level': 'error'})
from setroubleshoot.log import *

from setroubleshoot.errcode import *
from setroubleshoot.rpc import *
from setroubleshoot.server import *
from setroubleshoot.audit_data import *

import client_rpc

# Set by -v/--verbose; when true the list of test logs in use is printed.
verbose = False
# When true send_log_file() parses input with auparse instead of
# AuditRecordReader.  It is never set from the command line in this file.
use_auparse = False
# Text printed by input_prompt() before reading an interactive command.
prompt = "> "
# Last command line processed; an empty input line repeats it.
prev_command = None
# Maps test name (file basename without extension) to its file path.
test_files = {}
# Sorted list of the keys of test_files.
test_file_names = []
# Index into test_file_names of the file the 'send next' command sends next.
test_file_index = 0
# Connected clients; handlers add/remove themselves as connections open/close.
connection_pool = ConnectionPool()

# auparse is only imported when the auparse based parser is selected.
if use_auparse:
    import auparse
#------------------------------ Variables -------------------------------

# Provide a fallback value for audit.AUDIT_EOE when the installed audit
# module does not define it.
try:
    getattr(audit, 'AUDIT_EOE')
except AttributeError:
    audit.AUDIT_EOE = 1320

# Glob selecting which test log files to load (-s/--srcfiles).
srcfile_glob = '*'

# Candidate socket paths.  Only the text path is used as a default below;
# the binary path is not referenced elsewhere in the visible code.
binary_audit_socket_path = '/tmp/audit_events'
text_audit_socket_path = '/tmp/audispd_events'

# Name of the output record format (-f/--format); must be a key of the
# two *_map dictionaries below.
record_format_name = 'host_text'
# Resolved from record_format_map once the command line has been parsed.
record_format = None
# Socket path the listening server is opened on (-a).
audit_socket_path = text_audit_socket_path

# Where records are sent: 'unix' (connected socket clients) or 'stdout'.
output_dst = 'unix'
# Format name -> AuditRecordReader protocol format constant.
record_format_map = {
    'text': AuditRecordReader.TEXT_FORMAT,
    'host_text': AuditRecordReader.TEXT_FORMAT,
    'binary': AuditRecordReader.BINARY_FORMAT
}
# Format name -> AuditRecord method used to serialize a record for output;
# the selected entry is called as record_formatter(record).
record_formatter_map = {
    'text': AuditRecord.to_text,
    'host_text': AuditRecord.to_host_text,
    'binary': AuditRecord.to_binary,
}
# Protocol format constant -> printable name, used in log messages.
protocol_formats_to_text = {
    AuditRecordReader.TEXT_FORMAT: 'text',
    AuditRecordReader.BINARY_FORMAT: 'binary',
}

# Most recent record written by send_log_file(); used to build the trailing
# end-of-event record.
last_audit_record = None        # FIXME: should be passed as parameter, not global

#-----------------------------------------------------------------------------


class SubProcess(gobject.GObject):
    """Run a shell command and report its output through GObject signals.

    Signals:
        'exit'     -- callback(subprocess, status); emitted once, after the
                      child has been waited for.  status is the value
                      returned by Popen.wait().
        'readline' -- callback(subprocess, io_name, line); emitted for each
                      newline terminated line read from the child.  io_name
                      is 'stdout' or 'stderr' and line includes the newline.

    Example:
        def on_exit(subprocess, status):
            print "exit(%s) %s" % (status, subprocess.cmd)
        def readline(subprocess, io_name, line):
            print "readline(%s): %s" % (io_name, line),
        subprocess = SubProcess('cat run_daemon')
        subprocess.connect('exit', on_exit)
        subprocess.connect('readline', readline)
        subprocess.run()
    """

    __gsignals__ = {
        'exit':  # callback(subprocess, status)
        (gobject.SIGNAL_RUN_LAST, gobject.TYPE_NONE, (gobject.TYPE_INT,)),
        'readline':  # callback(subprocess, io_name, line)
        (gobject.SIGNAL_RUN_LAST, gobject.TYPE_NONE, (gobject.TYPE_STRING, gobject.TYPE_STRING,)),
    }

    class IO:
        """Per-stream state: the IO watch id and the pending partial line."""

        def __init__(self):
            self.watch_id = None
            self.buf = ''

    def __init__(self, cmd):
        """Remember the shell command line `cmd`; nothing runs until run()."""
        gobject.GObject.__init__(self)
        self.cmd = cmd
        # Not used by this class; kept so existing attribute access still works.
        self.stdout_buf = ''
        self.stderr_buf = ''
        # Maximum number of bytes read from a pipe per IO callback.
        self.bufsize = 1024

        self.sub_process = None
        self.status = None

        self.io = {}
        self.io['stdout'] = self.IO()
        self.io['stderr'] = self.IO()

    def _do_exit(self):
        """Wait for the child and emit 'exit'; only the first call has effect."""
        if self.status is None:
            self.status = self.sub_process.wait()
            self.emit('exit', self.status)

    def handle_io(self, file, io_condition, io_name):
        """IO watch callback for one of the child's pipes.

        file         -- pipe file object the watch was installed on
        io_condition -- gobject.IO_* condition flags that triggered the call
        io_name      -- 'stdout' or 'stderr'

        Returns True to keep the watch installed, False to remove it.
        """
        if io_condition & gobject.IO_IN:
            # Read straight from the descriptor: os.read() returns whatever
            # is available, whereas file.read(n) blocks the main loop until
            # n bytes have arrived or the pipe is closed.
            data = os.read(file.fileno(), self.bufsize)
            if data:
                stream = self.io[io_name]
                stream.buf = self.process_data(io_name, stream.buf + data)
                return True
            # A zero length read on a readable pipe means end of file.
            self._do_exit()
            return False

        if io_condition & (gobject.IO_HUP | gobject.IO_ERR):
            self._do_exit()
            return False

        return True

    def process_data(self, io_name, data):
        """Emit 'readline' for every complete line in `data`.

        Returns the trailing text that is not yet newline terminated so the
        caller can prepend it to the next chunk.
        """
        start = 0
        end = data.find('\n', start)
        while end >= 0:
            end += 1                # include newline
            line = data[start:end]
            self.emit('readline', io_name, line)
            start = end
            end = data.find('\n', start)

        data = data[start:]
        return data

    def run(self):
        """Start the command and watch its stdout and stderr pipes.

        The command is run through the shell (shell=True), so `cmd` must not
        be built from untrusted input.
        """
        self.sub_process = subprocess.Popen(self.cmd,
                                            stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                            close_fds=True, shell=True)

        conditions = gobject.IO_IN | gobject.IO_HUP | gobject.IO_ERR
        self.io['stdout'].watch_id = gobject.io_add_watch(self.sub_process.stdout,
                                                          conditions,
                                                          self.handle_io, 'stdout')
        self.io['stderr'].watch_id = gobject.io_add_watch(self.sub_process.stderr,
                                                          conditions,
                                                          self.handle_io, 'stderr')


gobject.type_register(SubProcess)

#================================== AuParse ==================================


def auparse_feed_callback(au, cb_event_type, audit_event_handler):
    if debug:
        log_avc.debug("auparse_feed_callback()")
    try:

        if cb_event_type == auparse.AUPARSE_CB_EVENT_READY:
            if not au.first_record():
                log_avc.error("auparse_feed_callback(): failed to position to first record")
                return

            timestamp = au.get_timestamp()
            if timestamp is None:
                log_avc.error("auparse_feed_callback(): failed to get timestamp")
                return

            event_id = AuditEventID(timestamp.sec, timestamp.milli, timestamp.serial, timestamp.host)

            while True:
                record_text = au.get_record_text()
                body_text = parse_audit_record_text(record_text)[3]
                record_type = au.get_type()
                line_number = au.get_line_number()

                au.first_field()
                fields = {}

                while True:
                    fields[au.get_field_name()] = au.get_field_str()
                    if not au.next_field():
                        break

                audit_event_handler(record_type, event_id, body_text, fields, line_number)

                if not au.next_record():
                    break

    except Exception, e:
        log_avc.exception("auparse_feed_callback() exception %s: %s", e.__class__.__name__, str(e))
        return

#================================== AuParse ==================================

#-----------------------------------------------------------------------------


class AuditClientConnectionHandler(ClientConnectionHandler):
    """Connection handler that registers connected clients in connection_pool."""

    def __init__(self, socket_address):
        ClientConnectionHandler.__init__(self, socket_address)

    def on_connection_state_change(self, connection_state, flags, flags_added, flags_removed):
        """Track the OPEN flag: drop the client on close, register it on open.

        On open the client also gets an `out_fd` file object wrapping its
        socket, which is what the send functions write records to.
        """
        ClientConnectionHandler.on_connection_state_change(
            self, connection_state, flags, flags_added, flags_removed)

        was_closed = flags_removed & ConnectionState.OPEN
        was_opened = flags_added & ConnectionState.OPEN

        if was_closed:
            print("client disconnect")
            connection_pool.remove_client(self)

        if was_opened:
            print("client connect")
            connection_pool.add_client(self)
            self.out_fd = self.socket_address.socket.makefile()
            # Automatic feeding from an idle handler is disabled:
            #idle_proc_id = gobject.idle_add(feed_clients().next)

#-----------------------------------------------------------------------------


def get_test_files_in_dir(dir, filter_glob=None, extension='.log'):
    """Return the sorted paths of the test files in a directory.

    dir         -- directory to search (name kept for compatibility although
                   it shadows the builtin)
    filter_glob -- shell glob selecting the files, with or without the
                   extension; None selects every file
    extension   -- file name extension appended to the glob

    Returns a sorted list of the matching paths (possibly empty).
    """
    if filter_glob is None:
        filter_glob = '*'

    # Drop an extension the caller already supplied so it is not appended
    # twice.  It is compared literally and only at the end of the glob: used
    # as a regular expression '.log' would also eat the 'alog' in 'catalog'.
    if extension and filter_glob.endswith(extension):
        filter_glob = filter_glob[:-len(extension)]

    return sorted(glob.glob(os.path.join(dir, filter_glob + extension)))


def send_log_file(filepath, out_fd):
    global last_audit_record
    comment_re = re.compile('#.*')
    directive_re = re.compile(r'^\s*%\s*(\w+)\s*=(.*)')
    last_audit_record = None

    def new_input_line(f):
        new_data = f.readline()
        if new_data == '':
            return None
        new_data = comment_re.sub('', new_data)
        match = directive_re.search(new_data)
        if match:
            name = match.group(1)
            value = match.group(2)
            #print "got processing directive %s = %s" % (name, value)
            return ''
        return new_data

    def new_audit_record_handler(record_type, event_id, body_text, fields, line_number):
        global last_audit_record

        record = AuditRecord(record_type, event_id, body_text, fields, line_number)
        last_audit_record = record
        output = record_formatter(record)
        if debug:
            log_avc.debug("output: %s", output)

        os.write(out_fd.fileno(), output)
        last_audit_record = record

    print >> sys.stderr, "sending: %s" % os.path.basename(filepath)
    try:
        f = open(filepath)
    except Exception, e:
        program.error("cannot open '%s' [%s]", filepath, e.strerror)
        return

    if use_auparse:
        auparser = auparse.AuParser(auparse.AUSOURCE_FEED)
        auparser.add_callback(auparse_feed_callback, new_audit_record_handler)

        while True:
            new_data = new_input_line(f)
            if new_data is None:
                break
            auparser.feed(new_data)

        # All done, flush all buffered events out
        auparser.flush_feed()

    else:
        record_reader = AuditRecordReader(AuditRecordReader.TEXT_FORMAT)

        while True:
            new_data = new_input_line(f)
            if new_data is None:
                break
            for (record_type, event_id, body_text, fields, line_number) in record_reader.feed(new_data):
                new_audit_record_handler(record_type, event_id, body_text, fields, line_number)

    # Send audit end of event to assure immediate sequential processing
    if last_audit_record is not None:
        end_of_event_record = AuditRecord('EOE', last_audit_record.event_id, '', {}, last_audit_record.line_number)
        output = record_formatter(end_of_event_record)
        os.write(out_fd.fileno(), output)

    f.close


def feed_clients():
    """Generator sending every test file to the connected clients.

    Yields True after each file has been sent (followed by a 4 second
    pause) and False once all files have been sent to a client.
    """
    for client in connection_pool.clients():
        destination = client.out_fd
        for path in sorted(test_files.values()):
            send_log_file(path, destination)
            time.sleep(4)           # FIXME, when AUDIT_EOE works we won't need this
            yield True
        yield False


def send_file_names(file_names):
    """Send the named test files to the destination selected by output_dst.

    'unix' sends to every connected client, 'stdout' writes to sys.stdout;
    any other value of output_dst sends nothing.
    """
    if output_dst == 'unix':
        send_file_names_to_clients(file_names)
    elif output_dst == 'stdout':
        send_file_names_to_fd(file_names, sys.stdout)


def send_file_names_to_clients(file_names):
    """Send the named test files to every client in the connection pool.

    Prints a notice and returns when no client is connected.
    """
    clients = list(connection_pool.clients())
    if not clients:
        print("no clients connected")
        return
    for client in clients:
        send_file_names_to_fd(file_names, client.out_fd)


def send_file_names_to_fd(file_names, out_fd):
    """Send each named test file, in the order given, to `out_fd`.

    Names are looked up in test_files; an unknown name raises KeyError.
    """
    for name in file_names:
        send_log_file(test_files[name], out_fd)
        #time.sleep(4)           # FIXME, when AUDIT_EOE works we won't need this


def input_prompt():
    """Print the command prompt, showing the repeatable previous command."""
    text = prompt
    if prev_command:
        text = prompt + "[%s] " % prev_command
    # Trailing comma: stay on the prompt line.
    print >> sys.stdout, text,
    sys.stdout.flush()


def handle_stdin(file, io_condition):
    """IO watch callback: read one command line from `file` and execute it.

    file         -- file object the watch was installed on (stdin)
    io_condition -- gobject IO condition flags (unused)

    Returns True to keep the watch installed.  At end of file it returns
    False, which removes the watch: a closed stdin stays readable forever,
    so keeping the watch would re-run the previous command in a busy loop.
    """
    line = file.readline()
    if not line:
        # readline() only returns an empty string at end of file.
        return False
    process_command_input(line)
    input_prompt()
    return True


def process_command_input(line):
    """Execute one interactive command line.

    An empty line repeats the previous command, if there was one.  The
    commands are send|s, list|l, help|h and quit|q; see the help text below
    for their arguments.  The line executed is remembered in prev_command.
    """
    global test_file_index, prev_command

    line = line.strip()
    #log_program.debug("process_command_input: line=%s", line)
    words = line.split()
    if not words:
        if not prev_command:
            return
        line = prev_command
        words = line.split()

    prev_command = line
    cmd = words.pop(0)

    if cmd in ('send', 's'):
        if not words or words[0] == 'all':
            send_file_names(test_file_names)
        elif words[0] == 'next':
            if not test_file_names:
                print("no test files")
                return
            # Wrap around once the end of the list has been reached.
            if test_file_index >= len(test_file_names):
                test_file_index = 0
            send_file_names([test_file_names[test_file_index]])
            test_file_index += 1
        else:
            # Every word is a glob over the test names; a dict is used to
            # drop names matched more than once.
            matched = {}
            for pattern in words:
                for name in fnmatch.filter(test_file_names, pattern):
                    matched[name] = 1
            files = matched.keys()
            if files:
                print("sending files %s" % files)
                send_file_names(files)
            else:
                print("no matches for %s" % ' '.join(words))
    elif cmd in ('list', 'l'):
        print(test_file_names)
    elif cmd in ('help', 'h'):
        print('''
send|s [all|next|<name>] default=all
                         all:    sends all the log files
                         next:   sends the next log file in the log file list
                         <name>: sends the <name> log file
list|l                   print the list of log files
quit|q                   exit the program
help|h                   print this information
''')
    elif cmd in ('quit', 'q'):
        sys.exit(0)
    else:
        print("command = '%s' not recognized" % cmd)


#-----------------------------------------------------------------------------


# Directory searched for test log files (-D/--test-file-dir).
test_file_dir = 'data/audit'


def usage():
    """Print the command line help, filled in with the current defaults."""
    substitutions = {
        'audit_socket_path': audit_socket_path,
        'valid_record_formats': '|'.join(record_format_map),
        'test_file_dir': test_file_dir,
    }
    print('''
-a --audit-socket	audit message socket (default=%(audit_socket_path)s)
-o --out    		- for stdout, unix for AF_UNIX socket
-s --srcfiles		shell glob syntax specifying src log file to send
-f --format %(valid_record_formats)s	output audit record format
-c --config section.option=value	set a configuration value
-h --help		display help info
-D --test-file-dir	search for test files in this directory (default=%(test_file_dir)s)
-v --verbose		turn on verbose output

run setroubleshootd this way (add -V for console debug output):
setroubleshootd -f -c audit.text_protocol_socket_path=/tmp/audispd_events

''' % substitutions)
try:
    opts, args = getopt.getopt(sys.argv[1:],
                               "a:o:s:f:c:hD:v",
                               ["audit-socket=", "out=", "srcfiles=", "format=", "config=", "help", "test-file-dir", "verbose"])
except getopt.GetoptError:
    # print help information and exit:
    usage()
    sys.exit(2)

for o, a in opts:
    if o in ("-o", "--out"):
        if a == '-':
            output_dst = 'stdout'
        elif a == 'stdout':
            output_dst = 'stdout'
        elif a == 'unix':
            output_dst = 'unix'
        else:
            print >> sys.stderr, "ERROR: output destination (%s) invalid, must be -|stdout|unix" % (a)
            sys.exit(1)

    if o in ("-s", "--srcfiles"):
        srcfile_glob = a

    if o in ("-D", "--test-file-dir"):
        test_file_dir = a

    if o in ("-a", "--socket"):
        audit_socket_path = a

    if o in ("-f", "--format"):
        if a not in record_format_map:
            print >> sys.stderr, "ERROR: record format (%s) invalid, must be one of %s" % \
                (a, ','.join(record_format_map.keys()))
            sys.exit(1)
        record_format_name = a

    if o in ("-c", "--config"):
        config_setting = a
        from setroubleshoot.config import parse_config_setting
        if not parse_config_setting(config_setting):
            log_program.error("could not parse config setting '%s'", config_setting)

    if o in ("-v", "--verbose"):
        verbose = True

    if o in ("-h", "--help"):
        usage()
        sys.exit()


# Resolve the selected format name into the protocol constant and the
# function used to serialize records.
record_format = record_format_map[record_format_name]
record_formatter = record_formatter_map[record_format_name]

listen_addresses = "{unix}%s" % (audit_socket_path)

# Index the test log files by their base name without extension.
test_files = {}
test_file_names = []
filepaths = get_test_files_in_dir(test_file_dir, srcfile_glob)
for filepath in filepaths:
    name = os.path.splitext(os.path.basename(filepath))[0]
    test_files[name] = filepath
test_file_names = sorted(test_files.keys())


if verbose:
    print("using test logs [%s]" % ' '.join(test_file_names))

if output_dst == 'stdout':
    # Non-interactive: write every test file to stdout and finish.
    send_file_names(test_file_names)
elif output_dst == 'unix':
    # Interactive: accept client connections and take commands from stdin.
    log_program.info("opening %s sockets in %s protocol format",
                     listen_addresses, protocol_formats_to_text[record_format])
    for listen_address in parse_socket_address_list(listen_addresses):
        listening_server = ListeningServer(listen_address, AuditClientConnectionHandler)
        listening_server.open()

    stdin_io_watch_id = gobject.io_add_watch(sys.stdin, gobject.IO_IN, handle_stdin)
    input_prompt()

    main_loop = gobject.MainLoop()
    main_loop.run()

else:
    log_program.error("unknown output destination (%s)", output_dst)
