# sgallagh / centos / centpkg
#
# Forked from centos/centpkg
# Copyright (c) 2018 - Red Hat Inc.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.  See http://www.gnu.org/copyleft/gpl.html for
# the full text of the license.


"""Interact with the Red Hat lookaside cache

We need to override the pyrpkg.lookasidecache module to handle our custom
download path.
"""

import io
import os
import pycurl
import sys

from pyrpkg.errors import InvalidHashType, UploadError, LayoutError
from pyrpkg.lookaside import CGILookasideCache
from pyrpkg.layout.layouts import DistGitLayout

from . import utils


def is_dist_git(folder):
    """
    Indicates if a folder is using a dist-git layout.

    Parameters
    ----------
    folder: str
        The directory to inspect.

    Returns
    -------
    bool
        True when `DistGitLayout.from_path` accepts `folder`, False when it
        rejects it by raising `LayoutError`.

    Notes
    -----
    Only `LayoutError` is treated as "not a dist-git layout". Any other
    exception (including `KeyboardInterrupt`) propagates to the caller
    instead of being silently turned into False.
    """
    try:
        DistGitLayout.from_path(folder)
    except LayoutError:
        return False
    return True


class StreamLookasideCache(CGILookasideCache):
    """
    CentOS Stream lookaside specialized class.

    It inherits most of its behavior from `pyrpkg.lookaside.CGILookasideCache`.
    Every override below only normalizes the repository name and then
    delegates to the parent implementation.
    """

    def __init__(
        self, hashtype, download_url, upload_url, client_cert=None, ca_cert=None
    ):
        super(StreamLookasideCache, self).__init__(
            hashtype, download_url, upload_url, client_cert=client_cert, ca_cert=ca_cert
        )

    @staticmethod
    def _repo_name(name):
        """
        Return the name to send to the lookaside cache.

        When the current working directory uses a dist-git layout the name is
        passed through `utils.get_repo_name`; otherwise it is returned as is.
        """
        return utils.get_repo_name(name) if is_dist_git(os.getcwd()) else name

    def get_download_url(self, name, filename, hash, hashtype=None, **kwargs):
        """
        Build the download URL for a file, using the normalized repo name.

        All arguments other than `name` are forwarded unchanged to the parent
        class, whose return value is returned as is.
        """
        return super(StreamLookasideCache, self).get_download_url(
            self._repo_name(name), filename, hash, hashtype=hashtype, **kwargs
        )

    def remote_file_exists(self, name, filename, hashstr):
        """
        Check if a remote file exists.

        This method inherits the behavior of its parent class from pyrpkg.

        It uses the internal `utils.get_repo_name` method to parse the name in case
        it is a scm url.

        Parameters
        ----------
        name: str
            The repository name and org.

        filename: str
            The filename (something.tar.gz).

        hashstr: str
            The hash string for the file.

        Raises
        ------
        SystemExit
            With exit status 1 when the parent check raises `UploadError`;
            the error is logged first.

        Returns
        -------
        bool
            A boolean value to indicate if the file exists.
        """
        _name = self._repo_name(name)

        try:
            return super(StreamLookasideCache, self).remote_file_exists(
                _name, filename, hashstr
            )
        except UploadError as e:
            self.log.error("Error checking for %s at %s", filename, self.upload_url)
            self.log.error(e)
            raise SystemExit(1)

    def upload(self, name, filename, hashstr, offline=False):
        """
        Uploads a file to lookaside cache.

        This method inherits the behavior of its parent class from pyrpkg.

        It uses the internal `utils.get_repo_name` method to parse the name in case
        it is a scm url.

        Parameters
        ----------
        name: str
            The repository name and org.

        filename: str
            The filename (something.tar.gz).

        hashstr: str
            The hash string for the file.

        offline: bool
            Forwarded to the parent class only when true, so the default
            call is identical to a plain three-argument upload.
            NOTE(review): forwarding assumes the installed pyrpkg
            `CGILookasideCache.upload` accepts an `offline` keyword - confirm.

        Raises
        ------
        pyrpkg.errors.rpkgError
            Raises specialized classes that inherit from pyrpkg base errors.

        Returns
        -------
        None
            Returns whatever the parent implementation returns.
        """
        _name = self._repo_name(name)
        parent = super(StreamLookasideCache, self)

        if offline:
            # Previously this flag was accepted and then silently dropped.
            return parent.upload(_name, filename, hashstr, offline=offline)

        return parent.upload(_name, filename, hashstr)

    def download(self, name, filename, hashstr, outfile, hashtype=None, **kwargs):
        """
        Downloads a file from lookaside cache to the local filesystem.

        This method inherits the behavior of its parent class from pyrpkg.

        It uses the internal `utils.get_repo_name` method to parse the name in case
        it is a scm url.

        Parameters
        ----------
        name: str
            The repository name and org.

        filename: str
            The filename (something.tar.gz).

        hashstr: str
            The hash string for the file.

        outfile: str
            Forwarded unchanged to the parent class (presumably the local
            destination path - verify against pyrpkg).

        hashtype: str, optional
            Forwarded unchanged to the parent class.

        Raises
        ------
        pyrpkg.errors.rpkgError
            Raises specialized implementations of `pyrpkg.errors.rpkgError`.

        Returns
        -------
        None
            Returns whatever the parent implementation returns.
        """
        return super(StreamLookasideCache, self).download(
            self._repo_name(name),
            filename,
            hashstr,
            outfile,
            hashtype=hashtype,
            **kwargs
        )


class CLLookasideCache(CGILookasideCache):
    """
    CentOS Linux lookaside specialized class.

    It inherits most of its behavior from `pyrpkg.lookaside.CGILookasideCache`
    and overrides the download URL so that it is built from the package
    name, the branch and the file hash.
    """

    def __init__(self, hashtype, download_url, upload_url, name, branch):
        """
        Parameters
        ----------
        hashtype, download_url, upload_url:
            Forwarded unchanged to the parent constructor.

        name: str
            Stored on the instance as `self.name`.

        branch: str
            Stored on the instance as `self.branch`; used as the branch
            component of every download URL.
        """
        # `name` and `branch` must not be forwarded positionally: the other
        # subclasses in this module show that the parent's optional arguments
        # are `client_cert` and `ca_cert`, so forwarding them would register
        # the package name and branch as certificate paths.
        super(CLLookasideCache, self).__init__(hashtype, download_url, upload_url)
        self.name = name
        self.branch = branch

    def get_download_url(self, name, filename, hash, hashtype=None, **kwargs):
        """
        Build the download URL as `<download_url>/<name>/<branch>/<hash>`.

        Only the last "/"-separated component of `name` is used. `filename`,
        `hashtype` and extra keyword arguments do not appear in the result.
        """
        # Kept as an instance attribute assignment (as before) so anything
        # reading `self.download_path` afterwards sees the same template.
        self.download_path = "%(name)s/%(branch)s/%(hash)s"
        path_dict = {
            "name": name.rsplit("/", 1)[-1],
            "filename": filename,
            "branch": self.branch,
            "hash": hash,
            "hashtype": hashtype,
        }
        return os.path.join(self.download_url, self.download_path % path_dict)


class SIGLookasideCache(CGILookasideCache):
    """
    CentOS SIG lookaside specialized class.

    It inherits most of its behavior from `pyrpkg.lookaside.CGILookasideCache`.
    The download URL is built from package name, branch and hash, and the
    existence check / upload requests are sent with GSS-Negotiate
    authentication enabled.
    """

    def __init__(
        self,
        hashtype,
        download_url,
        upload_url,
        name,
        branch,
        client_cert=None,
        ca_cert=None,
    ):
        super(SIGLookasideCache, self).__init__(
            hashtype, download_url, upload_url, client_cert=client_cert, ca_cert=ca_cert
        )

        self.name = name
        self.branch = branch

    def get_download_url(self, name, filename, hash, hashtype=None, **kwargs):
        """
        Build the download URL as `<download_url>/<name>/<branch>/<hash>`.

        Only the last "/"-separated component of `name` is used. `filename`,
        `hashtype` and extra keyword arguments do not appear in the result.
        """
        download_path = "%(name)s/%(branch)s/%(hash)s"
        path_dict = {
            "name": name.rsplit("/", 1)[-1],
            "filename": filename,
            "branch": self.branch,
            "hash": hash,
            "hashtype": hashtype,
        }
        return os.path.join(self.download_url, download_path % path_dict)

    def _sig_apply_cert(self, curl, option, cert_path):
        """Set a certificate option on `curl`, warning if the file is missing."""
        if cert_path is None:
            return

        if os.path.exists(cert_path):
            curl.setopt(option, cert_path)
        else:
            self.log.warning("Missing certificate: %s", cert_path)

    def _sig_post_form(self, post_data, show_progress=False):
        """
        POST `post_data` to `self.upload_url` and return the response.

        :param list post_data: The form fields, as accepted by
            `pycurl.HTTPPOST`.
        :param bool show_progress: When true, report transfer progress through
            `self.print_progress`.
        :returns: A `(status, output)` tuple: the HTTP response code and the
            response body as stripped bytes.
        :raises UploadError: When performing the request fails.
        """
        with io.BytesIO() as buf:
            c = pycurl.Curl()

            # The outer try/finally guarantees the curl handle is released
            # even when configuring it fails.
            try:
                c.setopt(pycurl.URL, self.upload_url)

                if show_progress:
                    c.setopt(pycurl.NOPROGRESS, False)
                    c.setopt(pycurl.PROGRESSFUNCTION, self.print_progress)

                c.setopt(pycurl.WRITEFUNCTION, buf.write)
                c.setopt(pycurl.HTTPPOST, post_data)

                self._sig_apply_cert(c, pycurl.SSLCERT, self.client_cert)
                self._sig_apply_cert(c, pycurl.CAINFO, self.ca_cert)

                c.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_GSSNEGOTIATE)
                c.setopt(pycurl.USERPWD, ":")

                try:
                    c.perform()
                    status = c.getinfo(pycurl.RESPONSE_CODE)
                except Exception as e:
                    raise UploadError(e)
            finally:
                c.close()

            return status, buf.getvalue().strip()

    def remote_file_exists(self, name, filename, hash):
        """Verify whether a file exists on the lookaside cache

        :param str name: The name of the module. (usually the name of the
            SRPM). This can include the namespace as well (depending on what
            the server side expects).
        :param str filename: The name of the file to check for.
        :param str hash: The known good hash of the file.
        :returns: True when the server answers ``Available``, False when it
            answers ``Missing``.
        :raises UploadError: When the request fails or the server answers
            with anything else.
        """
        # In a dist-git checkout the name is reduced via utils.get_repo_name.
        _name = utils.get_repo_name(name) if is_dist_git(os.getcwd()) else name

        post_data = [
            ("name", _name),
            ("%ssum" % self.hashtype, hash),
            ("filename", filename),
        ]

        status, output = self._sig_post_form(post_data)

        if status != 200:
            self.raise_upload_error(status)

        # Lookaside CGI script returns these strings depending on whether
        # or not the file exists:
        if output == b"Available":
            return True

        if output == b"Missing":
            return False

        # Something unexpected happened
        self.log.debug(output)
        raise UploadError("Error checking for %s at %s" % (filename, self.upload_url))

    def upload(self, name, filepath, hash):
        """Upload a source file

        Nothing is uploaded when the file is already present on the server.

        :param str name: The name of the module. (usually the name of the SRPM)
            This can include the namespace as well (depending on what the
            server side expects).
        :param str filepath: The full path to the file to upload.
        :param str hash: The known good hash of the file.
        :raises UploadError: When the request fails.
        """
        filename = os.path.basename(filepath)

        if self.remote_file_exists(name, filename, hash):
            self.log.info("File already uploaded: %s", filepath)
            return

        self.log.info("Uploading: %s", filepath)
        # NOTE(review): the existence check above normalizes `name` for
        # dist-git checkouts, while the upload posts it unchanged - confirm
        # this asymmetry is what the server expects.
        post_data = [
            ("name", name),
            ("%ssum" % self.hashtype, hash),
            ("file", (pycurl.FORM_FILE, filepath)),
        ]

        status, output = self._sig_post_form(post_data, show_progress=True)

        # Get back a new line, after displaying the download progress
        sys.stdout.write("\n")
        sys.stdout.flush()

        if status != 200:
            self.raise_upload_error(status)

        if output:
            self.log.debug(output)