#!/usr/bin/python3

import base64
import collections
import hashlib
import os
import re
import sys

# When True, LOG_DEBUG() writes its messages to stderr; otherwise they
# are dropped.
DEBUG = False
# Matches a chunk set header, e.g.
#   _ibdmp:1/3|chunks=2,md5=281cc34e13cb4a502abd340fd07c4020|
RE_HEADER = r'_ibdmp:[0-9]+/[0-9]+[|]chunks=[0-9]+,md5=[0-9a-f]+[|]'
# Matches a data chunk (ordinal and base64 payload), e.g.
#   _ibdmp:1/3|1:A/sl1cEofBASe64/|
RE_CHUNK = r'_ibdmp:[0-9]+/[0-9]+[|][0-9]+:[A-Za-z0-9+/=]+[|]'


def USAGE():
    """
    Print usage message to stderr and exit with status 2

    Program name shown in the message is the last path component
    of sys.argv[0].
    """
    prog = os.path.basename(sys.argv[0])
    lines = [
        "usage: %s path/to/console.log path/to/target.tar.xz" % prog,
        "",
        "Decode debug tarball emitted by leapp's initramfs in-band",
        "console debugger, ibdmp().",
    ]
    sys.stderr.writelines('%s\n' % line for line in lines)
    sys.exit(2)


def LOG_DEBUG(msg):
    """
    Write *msg* to stderr with 'DEBUG: ' prefix, but only if the
    module-level DEBUG flag is on
    """
    if not DEBUG:
        return
    line = 'DEBUG: %s\n' % msg
    sys.stderr.write(line)


def LOG_WARN(msg):
    """
    Write *msg* and a newline to stderr
    """
    line = '%s\n' % msg
    sys.stderr.write(line)


class IbdmpDecodeError(ValueError):
    """
    Dump could not be decoded from the console log

    Raised when no headers or chunks are found, or when a header
    or chunk payload is not valid.
    """


class UsageError(ValueError):
    """
    Script was called with wrong command line arguments
    """


class _Chunk:
    """
    Single data chunk parsed from one raw chunk string

    Holds position of the chunk within the chunk set (*ordinal*)
    and its piece of the encoded text (*payload*).
    """

    @classmethod
    def from_raw1(cls, raw_chunk):
        """
        Initialize chunk from single raw chunk string

        Raw chunk looks like this:

            _ibdmp:1/3|2:paDD3d==========|

        ie. the text between the first two '|' is 'ordinal:payload'.

        Raises IbdmpDecodeError if *raw_chunk* does not start with
        '_ibdmp:', or if the 'ordinal:payload' area is missing or
        its ordinal is not a number.
        """
        if not raw_chunk.startswith('_ibdmp:'):
            LOG_WARN("invalid chunk payload (no '_ibdmp:'?): %s"
                     % raw_chunk)
            raise IbdmpDecodeError(raw_chunk)
        try:
            body = raw_chunk.split('|')[1]
            # anything after a second ':' is ignored
            ordinal, payload = body.split(':')[:2]
            ordinal = int(ordinal)
        except (IndexError, ValueError):
            # Report malformed input the same way as a missing prefix
            # instead of leaking IndexError / plain ValueError.
            LOG_WARN("invalid chunk payload (no 'ordinal:payload'?): %s"
                     % raw_chunk)
            raise IbdmpDecodeError(raw_chunk)
        return cls(ordinal=ordinal, payload=payload)

    def __init__(self, ordinal, payload):
        self.ordinal = ordinal
        self.payload = payload


class Header:
    """
    Chunk set header

    Carries number of chunks in one chunk set (*chunks*), MD5 hex
    digest of the encoded content (*md5*) and number of chunk sets
    (*csets*).

    Headers compare and hash by *chunks* and *md5* only; *csets* does
    not take part in the comparison.
    """

    @classmethod
    def from_rawN(cls, raw_headers):
        """
        Initialize chunk header from header chunk candidates

        raw_headers is an iterable of strings that contain encoded chunk
        parameters for the whole chunk set, ie. number of chunks, number
        of iterations, and MD5 hash of the content encoded in the chunk set.

        Raw header chunks can be corrupted so this factory will choose
        winner based on prevalence.  Candidates that cannot be parsed
        at all are warned about and left out of the vote.

        For chunk set example in ChunkCounter docstring corresponding
        raw headers could look similar to this:

            _ibdmp:1/3|chunks=2,md5=281cc34e13cb4a502abd340fd07c4020|
            _ibdmp:2/3|chunks=2,md5=281cc34e13cb4a502abd340fd07c4020|
            _ibdmp:3/3|chun?s=2,md5=281cc34e13cb4a502abd340fd07c4020|

        In this case, the winner is the first and second one.

        Raises IbdmpDecodeError if no candidate could be parsed.
        """
        cntr = collections.Counter()
        for rh in raw_headers:
            try:
                cntr[cls._from_raw1(rh)] += 1
            except IbdmpDecodeError:
                # _from_raw1 already warned; one bad candidate must not
                # spoil the vote among the remaining ones
                continue
        if not cntr:
            LOG_WARN("no dumps found in this console log")
            raise IbdmpDecodeError()
        winner = cntr.most_common(1)[0][0]
        LOG_DEBUG("header winner: %s" % winner)
        return winner

    @classmethod
    def _from_raw1(cls, raw_header):
        """
        Initialize chunk header from single raw header string

        Raises IbdmpDecodeError if *raw_header* does not follow the
        '_ibdmp:I/N|chunks=C,md5=HEX|' layout.
        """
        try:
            prefix, body = raw_header.split('|')[:2]
            _, stats = prefix.split(':')
            chunks_kv, md5_kv = body.split(',')[:2]
        except ValueError:
            LOG_WARN("invalid header chunk payload (malformed?): %s"
                     % raw_header)
            raise IbdmpDecodeError(raw_header)
        if not chunks_kv.startswith('chunks='):
            LOG_WARN("invalid header chunk payload (no chunks=?): %s"
                     % raw_header)
            raise IbdmpDecodeError(raw_header)
        if not md5_kv.startswith('md5='):
            LOG_WARN("invalid header chunk payload (no md5=?): %s"
                     % raw_header)
            raise IbdmpDecodeError(raw_header)
        try:
            return cls(
                chunks=int(chunks_kv.split('=')[1]),
                md5=str(md5_kv.split('=')[1]),
                csets=int(stats.split('/')[1]),
            )
        except (IndexError, ValueError):
            LOG_WARN("invalid header chunk payload (bad numbers?): %s"
                     % raw_header)
            raise IbdmpDecodeError(raw_header)

    def __init__(self, chunks, md5, csets):
        self.chunks = chunks
        self.md5 = md5
        self.csets = csets

    def __eq__(self, othr):
        if not isinstance(othr, Header):
            # let Python try the other operand / fall back to identity
            return NotImplemented
        return (self.chunks, self.md5) == (othr.chunks, othr.md5)

    def __hash__(self):
        return hash((self.chunks, self.md5))

    def __ne__(self, othr):
        result = self.__eq__(othr)
        if result is NotImplemented:
            return NotImplemented
        return not result

    # '__neq__' is not a special method name, so Python never called it
    # for '!='.  The name is kept only for code that called it directly.
    __neq__ = __ne__

    def __str__(self):
        return ("Header(csets=%r,chunks=%r,md5=%r)"
                % (self.csets, self.chunks, self.md5))


class ChunkCounter:
    """
    Chunk collector

    Initialize with Header that you have some confidence in
    (see Header.from_rawN), and set of raw chunks.

    Chunks could be corrupted but they should come in N replicated
    sets, so for every position in the chunk set, the initializer
    will select most prevalent variant of the given chunk.

    Eg. if chunk set was:

        _ibdmp:1/3|1:A/sl1cEofBASe64/|
        _ibdmp:1/3|2:paDD3d==========|
        _ibdmp:2/3|1:A/sl1cEofBASe64/|
        _ibdmp:2/3|2:paDD3d========!=|
        _ibdmp:3/3|1:A/sl1cEofBASe64/|
        _ibdmp:3/3|2:paDD3d==========|

    on position 2, the corrupted chunk will be removed.

    Use decode() to get the encoded tarball bytes, or decode_to()
    to write it to a file.
    """

    def __init__(self, header, raw_chunks):
        """
        Parse *raw_chunks* and count payload variants per position

        Raises IbdmpDecodeError (from _Chunk.from_raw1) if a raw chunk
        cannot be parsed.
        """
        self.header = header
        # ordinal -> Counter of payload variants seen on that position
        self._bagset = collections.defaultdict(collections.Counter)
        LOG_DEBUG('header.chunks=%r' % header.chunks)
        for cr in raw_chunks:
            c = _Chunk.from_raw1(cr)
            LOG_DEBUG('c.ordinal=%r' % c.ordinal)
            self._bagset[c.ordinal][c.payload] += 1

    @property
    def chunks(self):
        """
        Selected chunks from all known

        List of winning payloads for positions 1 to header.chunks, in
        order.  Position with no chunk is warned about and left out.
        """
        out = []
        total = self.header.chunks
        csets = self.header.csets
        for idx in range(1, total + 1):
            cbag = self._bagset.get(idx)
            if not cbag:
                LOG_WARN('Missing chunk id: %d/%d' % (idx, total))
                continue
            winner, score = cbag.most_common(1)[0]
            # csets is read from the console log as well; a zero there
            # must not break decoding just because of a debug message
            confidence = 100 * (score / csets) if csets else 0
            LOG_DEBUG("chunk position winner: %d: %s (%d%%)"
                      % (idx, winner, confidence))
            out.append(winner)
        return out

    def decode(self):
        """
        Decode tarball from valid chunk data

        Returns decoded bytes.  MD5 mismatch against the header is only
        warned about; the data is returned anyway.

        Raises IbdmpDecodeError if the joined chunks are not valid
        base64.
        """
        encoded = ''.join(self.chunks)
        try:
            tarball = base64.b64decode(encoded)
        except ValueError as exc:
            # base64 reports bad input with binascii.Error, which is
            # a ValueError
            LOG_WARN("invalid base64 data: %s" % exc)
            raise IbdmpDecodeError(str(exc))
        tarball_md5 = hashlib.md5(tarball).hexdigest()
        if tarball_md5 != self.header.md5:
            LOG_WARN("MD5 mismatch: %s != %s" % (tarball_md5, self.header.md5))
        return tarball

    def decode_to(self, tarpath):
        """
        Decode and write tarball to *tarpath*.

        Decoding is done before the file is opened so that a failed
        decode does not leave an empty file behind.
        """
        tarball = self.decode()
        with open(tarpath, 'wb') as f:
            f.write(tarball)


def readwin2(fh):
    """
    From filehandle *fh*, yield joined lines 1+2, then 2+3,
    etc.  Trailing whitespace is stripped before joining.

    If there is only one line, it is yielded alone (stripped) so
    that its content is not skipped.  Empty file yields nothing.
    """
    prev = fh.readline()
    if not prev:
        return
    nxt = fh.readline()
    if not nxt:
        # single-line input: there is no pair to join
        yield prev.rstrip()
        return
    while nxt:
        yield prev.rstrip() + nxt.rstrip()
        prev, nxt = nxt, fh.readline()


def main(args):
    """
    Find dump in console log and write decoded tarball

    *args* is list of two items: path to the console log to read and
    path of the tarball to write.

    Raises UsageError if *args* does not have exactly two items and
    IbdmpDecodeError if no headers or no chunks are found in the log
    (or if Header / ChunkCounter fail to decode them).
    """
    LOG_DEBUG(args)
    try:
        source, target = args
    except ValueError:
        raise UsageError()

    header_re = re.compile(RE_HEADER)
    chunk_re = re.compile(RE_CHUNK)

    # Every line is seen in two joined windows, so the same match can
    # be found twice; sets keep just one copy of each.
    raw_headers = set()
    raw_chunks = set()

    # Console log may contain bytes that are not valid text; replace
    # them instead of crashing.  Replaced characters cannot match the
    # regexes, so such a chunk is simply not picked up.
    with open(source, errors='replace') as f:
        for jline in readwin2(f):
            raw_headers.update(header_re.findall(jline))
            raw_chunks.update(chunk_re.findall(jline))

    if not raw_headers:
        LOG_WARN("no headers found")
        raise IbdmpDecodeError()
    LOG_DEBUG("raw headers found: %d" % len(raw_headers))

    if not raw_chunks:
        LOG_WARN("no chunks found")
        raise IbdmpDecodeError()
    LOG_DEBUG("raw chunks found: %d" % len(raw_chunks))

    header = Header.from_rawN(raw_headers)
    ccounter = ChunkCounter(header, raw_chunks)
    ccounter.decode_to(target)


if __name__ == '__main__':
    # Exit status: 2 on bad usage (set by USAGE), 3 on decode failure.
    cli_args = sys.argv[1:]
    try:
        main(cli_args)
    except IbdmpDecodeError:
        sys.exit(3)
    except UsageError:
        USAGE()
