#! /usr/bin/python -u
#
# The katello Acquire Method
#
# Author: Matthias Dellweg <dellweg [at] atix.de>
# Licence: GPLv2
# Copyright (c) 2017 ATIX AG
#
# This is derived and partly borrowed from:
#
# > The Spacewalk Acquire Method
#
# > Author:  Simon Lukasik <xlukas08 [at] stud.fit.vutbr.cz>
# > Date:    2011-01-01
# > License: GPLv2
#
# > Copyright (c) 1999--2012 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.


from __future__ import print_function

import sys
import re
import hashlib

import requests


class pkg_acquire_method:
    """
    Minimal Python variant of apt's apt-pkg/acquire-method skeleton.

    The class speaks a line based, http-like text protocol: requests are
    read from stdin and replies are printed to stdout, every message being
    terminated by an empty line.  Constructing an instance immediately
    sends the "100 Capabilities" announcement.

    The class is not usable on its own: run() hands requests to
    self.fetch(msg) and reports errors through self.fail(message), and
    both methods have to be supplied by a subclass.
    """
    __eof = False

    def __init__(self):
        self.__send("100 Capabilities\nVersion: 1.0\nSingle-Instance: true\n")

    def __send(self, text):
        """
        Write one protocol message to stdout and flush it right away.

        print() appends the newline that produces the empty line closing
        the message.  The explicit flush keeps the dialogue going even when
        the interpreter was not started unbuffered (-u); otherwise a reply
        could sit in the stdout buffer instead of reaching the reader.
        """
        print(text)
        sys.stdout.flush()

    def __get_next_msg(self):
        """
        Read the next protocol message from stdin.

        A message is a "<number> <text>" status line followed by
        "Key: value" header lines, and it ends at an empty line or at EOF.

        Returns a dict with the numeric code under '_number', the status
        text under '_text' and one entry per header, or None once stdin is
        exhausted.  Raises ValueError if the status line does not start
        with an integer.
        """
        if self.__eof:
            return None
        line = sys.stdin.readline()
        # Skip blank lines in front of the status line.
        while line == '\n':
            line = sys.stdin.readline()
        if not line:
            self.__eof = True
            return None

        # partition() instead of split()[1]: a status line that carries no
        # text (e.g. "601\n") must not raise IndexError.
        number, _, text = line.partition(" ")
        result = {'_number': int(number), '_text': text.strip()}

        while True:
            line = sys.stdin.readline()
            if not line:
                # EOF in the middle of a message: return what was read.
                self.__eof = True
                return result
            if line == '\n':
                return result
            key, colon, value = line.partition(":")
            if not colon:
                # Not a "Key: value" line; ignore it rather than crash.
                continue
            result[key] = value.strip()

    def __dict2msg(self, msg):
        """
        Convert a dictionary to http like "Key: value" lines.

        Entries whose value is None are left out.  Values are formatted
        with %s, so non-string values such as integers are accepted too.
        """
        return "".join("%s: %s\n" % (key, value)
                       for key, value in msg.items()
                       if value is not None)

    def status(self, **kwargs):
        """Send a "102 Status" message built from the keyword arguments."""
        self.__send("102 Status\n%s" % self.__dict2msg(kwargs))

    def uri_start(self, msg):
        """Send a "200 URI Start" message with the fields of *msg*."""
        self.__send("200 URI Start\n%s" % self.__dict2msg(msg))

    def uri_done(self, msg):
        """Send a "201 URI Done" message with the fields of *msg*."""
        self.__send("201 URI Done\n%s" % self.__dict2msg(msg))

    def uri_failure(self, msg):
        """Send a "400 URI Failure" message with the fields of *msg*."""
        self.__send("400 URI Failure\n%s" % self.__dict2msg(msg))

    def run(self):
        """
        Loop through requests on stdin.

        Every message numbered 600 is passed to self.fetch(); any exception
        raised there is converted into a failure report via self.fail().

        Returns 0 when stdin is exhausted and 100 as soon as a message
        with any other number arrives.
        """
        while True:
            msg = self.__get_next_msg()
            if msg is None:
                return 0
            if msg['_number'] != 600:
                return 100
            try:
                self.fetch(msg)
            except Exception as e:
                self.fail(e.__class__.__name__ + ": " + str(e))


class katello_method(pkg_acquire_method):
    """
    Katello acquire method.

    Handles URIs of the form katello://[<entitlement>@]<host>/<path>.
    With an entitlement the download goes over https, presenting the
    matching certificate and key from /etc/pki/entitlement and verifying
    the server against /etc/rhsm/ca/katello-server-ca.pem.  Without an
    entitlement it goes over plain http.
    """
    current_url = None
    http_headers = None
    not_registered_msg = 'This system is not registered'
    # Class-level defaults so that fail() works even when a request breaks
    # before fetch() managed to record these values.
    uri = None
    entitlement = None

    def fail(self, message=not_registered_msg):
        """Report a "400 URI Failure" carrying *message* for the current URI."""
        self.uri_failure({'URI': self.uri,
                          'Message': message})

    def __parse_uri(self, uri):
        """
        Translate a katello:// URI into the http(s) URL to download.

        Stores the entitlement (or None) in self.entitlement and, if one is
        given, the CA, certificate and key paths in self.ssl_cacert,
        self.ssl_cert and self.ssl_key.

        Returns the URL with an https:// prefix when an entitlement is
        present and an http:// prefix otherwise.  Raises Exception if
        *uri* does not use the katello:// scheme.
        """
        self.entitlement = None
        # The entitlement may contain neither '/' nor '@'.  That confines it
        # to the authority part, so an '@' further down the path is not
        # mistaken for the entitlement separator.
        result = re.match(r"^katello://((?P<entitlement>[^/@]*)@)?(?P<url>.*)$", uri)
        if not result:
            raise Exception('Protocol mismatch')
        self.entitlement = result.group('entitlement')
        if self.entitlement:
            protocol = 'https://'
            self.ssl_cacert = '/etc/rhsm/ca/katello-server-ca.pem'
            self.ssl_cert = '/etc/pki/entitlement/{}.pem'.format(self.entitlement)
            self.ssl_key = '/etc/pki/entitlement/{}-key.pem'.format(self.entitlement)
        else:
            protocol = 'http://'
        return protocol + result.group('url')

    def fetch(self, msg):
        """
        Fetch the content from pulp server to the file.

        According to the apt protocol msg must contain: 'URI' and 'Filename'.
        Other possible keys are: 'Last-Modified', 'Index-File', 'Fail-Ignore'

        A response other than 200 is reported as "400 URI Failure".  On
        success a "200 URI Start" is sent before the download and a
        "201 URI Done" with the size and the MD5/SHA256 digests of the
        written data afterwards.  Exceptions (missing keys, network or
        file errors) are not handled here and propagate to the caller.
        """
        # Forget the previous request first: if this message has no URI the
        # resulting failure must not be reported against the old one.
        self.uri = None
        self.uri = msg['URI']
        self.url = self.__parse_uri(self.uri)
        self.filename = msg['Filename']

        self.status(URI=self.uri, Message='Waiting for headers')

        ssl_options = {}
        if self.entitlement:
            ssl_options['verify'] = self.ssl_cacert
            ssl_options['cert'] = (self.ssl_cert, self.ssl_key)

        req = requests.get(self.url, stream=True, **ssl_options)
        try:
            if req.status_code != 200:
                self.uri_failure({'URI': self.uri,
                                  'Message': str(req.status_code) + '  ' + req.reason,
                                  'FailReason': 'HttpError' + str(req.status_code)})
                return

            last_modified = req.headers.get('last-modified')
            self.uri_start({'URI': self.uri,
                            'Size': req.headers.get('content-length'),
                            'Last-Modified': last_modified})

            hash_sha256 = hashlib.sha256()
            hash_md5 = hashlib.md5()
            written = 0
            with open(self.filename, "wb") as f:
                for data in req.iter_content(chunk_size=4096):
                    hash_sha256.update(data)
                    hash_md5.update(data)
                    f.write(data)
                    written += len(data)
        finally:
            req.close()

        # Report the number of bytes actually written rather than the
        # Content-Length header: the header is optional and describes the
        # transferred body, which need not match the file on disk.
        self.uri_done({'URI': self.uri,
                       'Filename': self.filename,
                       'Size': str(written),
                       'Last-Modified': last_modified,
                       'MD5-Hash': hash_md5.hexdigest(),
                       'MD5Sum-Hash': hash_md5.hexdigest(),
                       'SHA256-Hash': hash_sha256.hexdigest()})


if __name__ == '__main__':
    # Script entry point: serve requests until stdin is closed and use the
    # value returned by run() as the process exit status.
    try:
        sys.exit(katello_method().run())
    except KeyboardInterrupt:
        # An interrupt ends the method quietly, without a traceback.
        pass
