#
# Copyright (c) 2023 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#           http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import requests
import time
import os
import stat
from jwcrypto import jwt
from datetime import datetime, timedelta
from requests import HTTPError
from redhat_support_lib.utils import confighelper
from redhat_support_lib.infrastructure.s3_uploader import check_for_proxy, check_ssl_params

# NOTE: the second assignment below rebinds __author__, so only the last value is kept at runtime.
__author__ = 'Pranita Ghole pghole@redhat.com'
__author__ = 'Swaraj Pande spande@redhat.com'
# __author__ = 'Vitaliy Dymna'

# Client id sent with the device-code, token-polling and refresh-token requests.
DEVICE_AUTH_CLIENT_ID = "redhat-support-tool"
# grant_type value sent while polling the token endpoint during the device flow.
GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
# Name of the file (inside the tool directory) that holds the locally cached refresh token.
OIDC_TOKENS_FILE = ".token.json"
# strftime/strptime pattern used for the persisted "refresh_expires_at" value.
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

logger = logging.getLogger(__name__)
# Shared configuration helper; this module reads its `username` and `save_token` attributes.
config = confighelper.get_config_helper()

# Directory under the user's home where the token file is read from and written to.
tool_dir = str(os.path.join(os.path.expanduser('~'), '.redhat-support-tool/'))


def try_read_refresh_token(base_path):
    """
    Try to read locally stored refresh token

    The token is only returned when the file holds a JSON object whose refresh token is still
    valid for at least another 300 seconds and whose "username" matches the configured user.
    A missing, empty, malformed, expired or foreign token file never raises; it yields None.

    :param base_path: Local path where OIDC token is stored in string format

    :return: OIDC refresh token if found otherwise None
    """

    tokens_file_path = os.path.join(base_path, OIDC_TOKENS_FILE)
    if not os.path.exists(tokens_file_path):
        logger.info("Cannot find {} file using {} path".format(OIDC_TOKENS_FILE, base_path))
        return None

    logger.info("Retrieving the token information from local {} file".format(OIDC_TOKENS_FILE))
    with open(tokens_file_path, "r") as fileobj:
        read_in = fileobj.read()

    if not read_in:
        logger.info("The {} file was empty.".format(OIDC_TOKENS_FILE))
        return None

    try:
        token_data = json.loads(read_in)
    except json.decoder.JSONDecodeError:
        logger.debug("There was a json exception thrown loading {}.".format(OIDC_TOKENS_FILE))
        return None

    # Valid JSON is not necessarily an object (it may be a list, a string, ...); only a dict can
    # carry the fields we need, anything else is treated as an incorrectly formatted file.
    if isinstance(token_data, dict):
        refresh_token = token_data.get("refresh_token")
        dt_str = token_data.get("refresh_expires_at")
        username = token_data.get("username")

        refresh_expires_at = None
        if dt_str:
            try:
                refresh_expires_at = datetime.strptime(dt_str, DATETIME_FORMAT)
            except (TypeError, ValueError):
                # A hand-edited or corrupted timestamp must not crash the tool.
                refresh_expires_at = None

        # Comparing against "now + margin" (instead of "expiry - margin") avoids an OverflowError
        # for timestamps close to datetime.min.
        if refresh_token and refresh_expires_at and refresh_expires_at > datetime.utcnow() + timedelta(seconds=300):
            if username == config.username:
                return refresh_token
            logger.debug("Locally cached token does not belong to the currently configured user.")
            return None

    logger.debug("Locally cached token has expired or is incorrectly formatted. "
                 "Please remove the file {}".format(OIDC_TOKENS_FILE))
    return None


def parse_token(token):
    """
    Decode the payload (claims) of a serialized JWT.

    The token is deserialized with ``key=None``.
    NOTE(review): this appears to mean the signature is not verified here - confirm against jwcrypto.

    :param token: Serialized JWT, e.g. the id_token returned by the token endpoint

    :return: Dictionary with the decoded payload, or None when the token carries no payload
    :raises Exception: when the token cannot be deserialized or its payload is not valid JSON
    """
    try:
        jwt_token = jwt.JWT(key=None, jwt=token)
        payload = jwt_token.token.objects.get('payload', None)
        if payload:
            return json.loads(payload)
    except Exception:
        # Only regular errors are translated; KeyboardInterrupt / SystemExit must keep propagating.
        logger.error("Cannot retrieve the logged in user from the JWT token received.")
        raise Exception("Cannot retrieve the logged in user from the JWT token received.")
    return None


class AuthClass:
    """
    Authorisation Class. It supports device authorisation flow and client credentials grant.

    When both ``client_id`` and ``client_secret`` are supplied the client credentials grant is used,
    otherwise the device authorisation flow is started (re-using a locally cached refresh token when
    one is available). Note that the constructor already performs the network requests needed to
    obtain the first access token.
    """

    # Polling interval in seconds used when the device authorization response has no "interval";
    # RFC 8628 section 3.2 defines 5 seconds as the default.
    DEFAULT_POLL_INTERVAL = 5
    # Seconds added to the polling interval every time the token endpoint answers "slow_down"
    # (RFC 8628 section 3.5).
    SLOW_DOWN_INCREMENT = 5

    def __init__(self, kwargs):
        """
        Read the endpoints/credentials from the given mapping and fetch the first access token.

        :param kwargs: Dictionary providing 'client_identifier_url', 'token_endpoint' and optionally
                       'client_id' / 'client_secret' (service account credentials)
        """
        self.context = kwargs
        self._access_token = None
        self._access_expires_at = None
        self._refresh_token = None
        self._refresh_expires_at = None
        self._user_verification_url = None
        self.__device_code = None
        self.logged_in_user = None

        # Filled in by the device authorisation flow; initialised here so they always exist.
        self._id_token = None
        self._user_code = None
        self._verification_uri = None
        self._verification_uri_complete = None
        self._interval = None

        self.proxies = check_for_proxy()
        self.verify = check_ssl_params()
        self._save_token = config.save_token

        self.client_identifier_url = kwargs.get('client_identifier_url')
        self.token_endpoint = kwargs.get('token_endpoint')

        self.client_id = kwargs.get('client_id')
        self.client_secret = kwargs.get('client_secret')

        if self.client_id and self.client_secret:
            self.grant_type = "client_credentials"
            self._use_client_credentials_grant()
        else:
            self.grant_type = "device_auth"
            self._use_device_code_grant()

    @staticmethod
    def _expiry_from_now(seconds):
        """
        Convert a lifetime in seconds into an absolute (naive UTC) expiry time.

        :param seconds: Lifetime in seconds, or None when the server did not report one
        :return: datetime of the expiry, or None when the lifetime is unknown
        """
        if seconds is None:
            return None
        return datetime.utcnow() + timedelta(seconds=seconds)

    @staticmethod
    def _error_code(response):
        """
        Extract the "error" field of an error response without raising.

        :param response: Response object returned by requests
        :return: The error code as a string, or an empty string when the body has none
        """
        try:
            body = response.json()
        except ValueError:
            # The body is not JSON (e.g. an HTML error page produced by a proxy).
            return ""
        if isinstance(body, dict):
            return str(body.get("error") or "")
        return ""

    def _use_client_credentials_grant(self):
        """
        Use the client credentials grant method to generate the tokens.
        Service accounts have the client_id and client_secret associated with it, which will be used.

        :raises Exception: when the token endpoint answers with an HTTP error status
        """

        logger.info("Using the Service Account for authorization {} ".format(self.client_id))
        headers = {
            "Content-Type": "application/x-www-form-urlencoded"
        }
        # Passing a dict lets requests url-encode the values, so secrets containing characters such
        # as '&', '+' or '=' are transmitted intact.
        payload = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

        response = requests.post(self.token_endpoint, data=payload, headers=headers, proxies=self.proxies,
                                 verify=self.verify)
        try:
            response.raise_for_status()
        except HTTPError:
            message = "Error while fetching token using Service Account {} {}".format(response.status_code,
                                                                                      response.text)
            logger.error(message)
            raise Exception(message)

        token_data = response.json()
        self._access_expires_at = self._expiry_from_now(token_data.get("expires_in"))
        self._access_token = token_data.get("access_token")

    def _use_device_code_grant(self):
        """
        Start the device auth flow. First read the tokens from the local file(if user has configured to save it). If the
        token is not present or is expired, start the device auth flow.
        """

        stored_refresh_token = try_read_refresh_token(tool_dir)
        if stored_refresh_token:
            self._use_refresh_token_grant(stored_refresh_token)
            return

        self._request_device_code()
        print("Please visit the following URL in the browser to login and authorize - {} ".format(self._verification_uri_complete))
        self.poll_for_auth_completion()

    def _request_device_code(self):
        """
        Initialize new Device Authorization Grant attempt by requesting a new device code.

        :raises HTTPError: when the device authorization endpoint answers with an HTTP error status
        """

        data = {"client_id": DEVICE_AUTH_CLIENT_ID, "scope": "openid"}
        headers = {'content-type': 'application/x-www-form-urlencoded'}
        try:
            res = requests.post(self.client_identifier_url,
                                data=data,
                                headers=headers,
                                proxies=self.proxies,
                                verify=self.verify)
            res.raise_for_status()
        except HTTPError as e:
            logger.error("There is a error while requesting the device code {}".format(e))
            raise

        response = res.json()
        self._user_code = response.get("user_code")
        self._verification_uri = response.get("verification_uri")
        # "interval" is optional in the response; fall back to the default polling interval.
        self._interval = response.get("interval") or self.DEFAULT_POLL_INTERVAL
        self.__device_code = response.get("device_code")
        self._verification_uri_complete = response.get("verification_uri_complete")

    def poll_for_auth_completion(self):
        """
        Continuously poll OIDC token endpoint until the user is successfully authenticated or an error occurs.

        :raises Exception: when the token endpoint reports anything other than success,
                           "authorization_pending" or "slow_down"
        """

        token_data = {'grant_type': GRANT_TYPE_DEVICE_CODE,
                      'client_id': DEVICE_AUTH_CLIENT_ID,
                      'device_code': self.__device_code,
                      'scope': 'openid'}

        # When re-authenticating, an expired access token from the previous session is still set.
        # Drop it, otherwise the loop below would never run and the stale token would be handed out.
        if not self.is_access_token_valid():
            self._access_token = None

        interval = self._interval or self.DEFAULT_POLL_INTERVAL

        while self._access_token is None:
            time.sleep(interval)
            try:
                check_auth_completion = requests.post(self.token_endpoint, data=token_data, proxies=self.proxies,
                                                      verify=self.verify)

                status_code = check_auth_completion.status_code

                if status_code == 200:
                    logger.info("The SSO authentication is successful")
                    self._set_token_data(check_auth_completion.json())
                elif status_code == 400:
                    error = self._error_code(check_auth_completion)
                    if error == "slow_down":
                        # The server asked us to back off: poll less frequently from now on.
                        interval += self.SLOW_DOWN_INCREMENT
                    elif error != "authorization_pending":
                        raise Exception(status_code, check_auth_completion.text)
                else:
                    raise Exception(status_code, check_auth_completion.text)
            except Exception as e:
                logger.error("Exception occurred while polling for authentication")
                raise e

    def _set_token_data(self, token_data):
        """
        Set the class attributes as per the input token_data received.

        The identity carried by the id_token is checked first; the token attributes are only updated
        once that check passed, so a login performed with a different user never leaves that user's
        tokens on the instance.

        :param token_data: Token data containing access_token, refresh_token and their expiry etc.
        :raises Exception: when the id_token is missing or belongs to a user other than the configured one
        """

        id_token = token_data.get("id_token")
        if not id_token:
            raise Exception("The token information received from the auth cannot retrieve the logged in user details !")

        parsed_id_token = parse_token(id_token)
        if parsed_id_token:
            self.logged_in_user = parsed_id_token.get("preferred_username")

        if self.logged_in_user and self.logged_in_user != config.username:
            msg = ("The configured user in the tool seems to be different from the user logged in the browser. "
                   "Please verify and configure correctly !")
            print(msg)
            logger.error(msg)
            raise Exception(msg)

        self._access_token = token_data.get("access_token")
        # A missing lifetime yields None, which makes the token count as "not valid" instead of
        # crashing with a TypeError.
        self._access_expires_at = self._expiry_from_now(token_data.get("expires_in"))
        self._refresh_token = token_data.get("refresh_token")
        self._id_token = id_token

        refresh_expires_in = token_data.get("refresh_expires_in")
        if refresh_expires_in == 0:
            # A lifetime of 0 is stored as "never expires".
            self._refresh_expires_at = datetime.max
        else:
            self._refresh_expires_at = self._expiry_from_now(refresh_expires_in)

        if self._save_token:
            self.persist_refresh_token(tool_dir)

    def get_access_token(self):
        """
        Get the valid access_token at any given time.

        A still valid token is returned as is. Otherwise a new one is fetched: through the client
        credentials grant for service accounts, or through the refresh token (falling back to a new
        device authorisation) for the device flow.

        :return: Access_token
        :rtype: string
        """

        if self.is_access_token_valid():
            return self._access_token

        if self.grant_type == "client_credentials":
            self._use_client_credentials_grant()
            return self._access_token

        if self.grant_type == "device_auth":
            if self.is_refresh_token_valid():
                self._use_refresh_token_grant()
            else:
                self.request_new_device_code()
            return self._access_token

        return None

    def is_access_token_valid(self):
        """
        Check the validity of access_token. We are considering it invalid 180 sec. prior to it's exact expiry time.
        :return: True/False
        """
        return bool(self._access_token
                    and self._access_expires_at
                    and self._access_expires_at - timedelta(seconds=180) > datetime.utcnow())

    def is_refresh_token_valid(self):
        """
        Check the validity of refresh_token. We are considering it invalid 180 sec. prior to it's exact expiry time.

        :return: True/False
        """
        return bool(self._refresh_token
                    and self._refresh_expires_at
                    and self._refresh_expires_at - timedelta(seconds=180) > datetime.utcnow())

    def _use_refresh_token_grant(self, refresh_token=None):
        """
        Fetch the new access_token and refresh_token using the existing refresh_token.

        When the server rejects the refresh token as invalid, the cached token file is removed and a
        new device authorisation is started.

        :param refresh_token: optional param for refresh_token
        :raises Exception: when the token endpoint fails for any other reason
        """
        refresh_token_data = {'client_id': DEVICE_AUTH_CLIENT_ID,
                              'grant_type': 'refresh_token',
                              'refresh_token': self._refresh_token if not refresh_token else refresh_token,
                              'scope': 'openid'}

        refresh_token_res = requests.post(self.token_endpoint, data=refresh_token_data, proxies=self.proxies, verify=self.verify)
        status_code = refresh_token_res.status_code

        if status_code == 200:
            self._set_token_data(refresh_token_res.json())
            return

        error = self._error_code(refresh_token_res)

        if status_code == 400 and 'invalid' in error:
            logger.warning("Problem while fetching the new tokens from refresh token grant - {} {}."
                           " New Device code will be requested !". format(status_code, error))

            # The rejected token must not be offered again, neither from memory nor from disk.
            self._refresh_token = None
            self._refresh_expires_at = None
            try:
                os.remove(os.path.join(tool_dir, OIDC_TOKENS_FILE))
            except FileNotFoundError:
                # Nothing was cached (e.g. saving tokens is disabled); that is fine.
                logger.debug("No locally cached token file to remove.")
            self.request_new_device_code()
        else:
            # Fall back to the raw body when the response carries no "error" field.
            raise Exception("Something went wrong while using the Refresh token grant for fetching tokens - {} "
                            "{}".format(status_code, error or refresh_token_res.text))

    def request_new_device_code(self):
        """
        Initialize new Device Authorization Grant attempt by requesting a new device code.
        """
        self._use_device_code_grant()

    def persist_refresh_token(self, base_path):
        """
        Persist current refresh token in a local file.

        The file is created readable and writable by its owner only. Errors raised while creating
        the directory or writing the file are propagated to the caller.

        :param base_path: Local path in string to a directory where token should be stored

        :return: True if refresh token was successfully persisted, otherwise False
        """

        if not self.is_refresh_token_valid():
            logger.info("Cannot save invalid refresh token !")
            return False

        if not os.path.exists(base_path):
            os.makedirs(base_path)

        token_data = {
            "username": self.logged_in_user,
            "refresh_token": self._refresh_token,
            "refresh_expires_at": self._refresh_expires_at.strftime(DATETIME_FORMAT)
        }
        tokens_file_path = os.path.join(base_path, OIDC_TOKENS_FILE)
        owner_only = stat.S_IRUSR | stat.S_IWUSR

        # Create the file with restrictive permissions right away, so the token is never readable by
        # other users, not even for the short moment between writing and a later chmod.
        fd = os.open(tokens_file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_only)
        with os.fdopen(fd, "w") as fileobj:
            # The creation mode is ignored for an already existing file, so tighten it explicitly
            # before the secret is written.
            os.chmod(tokens_file_path, owner_only)
            json.dump(token_data, fileobj)

        logger.info("The new refresh token was successfully saved to {}".format(tokens_file_path))
        return True
