// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Win32.SafeHandles;
using System.Diagnostics;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Security.Principal;
using System.Threading;

namespace System.Net
{
    internal static partial class CertificateValidationPal
    {
        // Gate taken by EnsureStoreOpened while it opens and publishes one of the cached stores below.
        private static readonly object s_syncObject = new object();

        // Cached "My" store for StoreLocation.CurrentUser; stays null until EnsureStoreOpened(false) succeeds.
        private static volatile X509Store s_myCertStoreEx;
        // Cached "My" store for StoreLocation.LocalMachine; stays null until EnsureStoreOpened(true) succeeds.
        private static volatile X509Store s_myMachineCertStoreEx;

        //
        // Builds the chain for the remote certificate and translates the outcome into SslPolicyErrors.
        //
        //   chain             - chain object used to build and inspect the remote certificate's chain.
        //   remoteCertificate - certificate presented by the remote party.
        //   checkCertName     - when true, the SSL chain policy is evaluated against hostName.
        //   isServer          - true when the local side is the server (the remote party is then a client).
        //   hostName          - name handed to the SSL policy as the expected server name.
        //
        // Returns RemoteCertificateNameMismatch when the policy reports CERT_E_CN_NO_MATCH and
        // RemoteCertificateChainErrors when the built chain carries any status entries.
        // Throws CryptographicException when the build fails without producing a chain handle.
        //
        internal static SslPolicyErrors VerifyCertificateProperties(
            X509Chain chain,
            X509Certificate2 remoteCertificate,
            bool checkCertName,
            bool isServer,
            string hostName)
        {
            SslPolicyErrors policyErrors = SslPolicyErrors.None;

            // A failed build is tolerated as long as a chain handle exists: the failure is then
            // a policy failure that is reported through the chain status further down.
            bool chainBuilt = chain.Build(remoteCertificate);
            if (!chainBuilt && chain.SafeHandle.DangerousGetHandle() == IntPtr.Zero)
            {
                throw new CryptographicException(Marshal.GetLastWin32Error());
            }

            if (checkCertName)
            {
                unsafe
                {
                    var sslPolicyPara = new Interop.Crypt32.SSL_EXTRA_CERT_CHAIN_POLICY_PARA();
                    sslPolicyPara.cbSize = (uint)Marshal.SizeOf<Interop.Crypt32.SSL_EXTRA_CERT_CHAIN_POLICY_PARA>();
                    // Authenticate the remote party: (e.g. when operating in server mode, authenticate the client).
                    sslPolicyPara.dwAuthType = isServer
                        ? Interop.Crypt32.AuthType.AUTHTYPE_CLIENT
                        : Interop.Crypt32.AuthType.AUTHTYPE_SERVER;
                    sslPolicyPara.fdwChecks = 0;
                    sslPolicyPara.pwszServerName = null;

                    var chainPolicyPara = new Interop.Crypt32.CERT_CHAIN_POLICY_PARA();
                    chainPolicyPara.cbSize = (uint)Marshal.SizeOf<Interop.Crypt32.CERT_CHAIN_POLICY_PARA>();
                    chainPolicyPara.dwFlags = 0;
                    chainPolicyPara.pvExtraPolicyPara = &sslPolicyPara;

                    // Ignore every policy problem except an invalid name; only the name check matters here.
                    chainPolicyPara.dwFlags |=
                        (Interop.Crypt32.CertChainPolicyIgnoreFlags.CERT_CHAIN_POLICY_IGNORE_ALL &
                         ~Interop.Crypt32.CertChainPolicyIgnoreFlags.CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG);

                    // The host name must stay pinned for as long as the native call can read it.
                    fixed (char* pinnedHostName = hostName)
                    {
                        sslPolicyPara.pwszServerName = pinnedHostName;

                        uint policyStatus = Verify(chain.SafeHandle, ref chainPolicyPara);
                        if (policyStatus == Interop.Crypt32.CertChainPolicyErrors.CERT_E_CN_NO_MATCH)
                        {
                            policyErrors |= SslPolicyErrors.RemoteCertificateNameMismatch;
                        }
                    }
                }
            }

            // Any status entry on the built chain is surfaced as a chain error.
            X509ChainStatus[] chainStatuses = chain.ChainStatus;
            bool hasChainStatus = chainStatuses != null && chainStatuses.Length != 0;
            if (hasChainStatus)
            {
                policyErrors |= SslPolicyErrors.RemoteCertificateChainErrors;
            }

            return policyErrors;
        }

        //
        // Extracts a remote certificate upon request.
        //
        //   securityContext             - security context to query; when null the method returns null
        //                                 and leaves remoteCertificateCollection null.
        //   remoteCertificateCollection - receives the certificates obtained through
        //                                 UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext;
        //                                 null when no valid remote certificate context was obtained.
        //
        // Returns a new X509Certificate2 created from the remote certificate context handle, or null
        // when the query did not produce a valid context.
        //
        internal static X509Certificate2 GetRemoteCertificate(SafeDeleteContext securityContext, out X509Certificate2Collection remoteCertificateCollection)
        {
            remoteCertificateCollection = null;

            if (securityContext == null)
            {
                return null;
            }

            if (GlobalLog.IsEnabled)
            {
                GlobalLog.Enter("CertificateValidationPal.Windows SecureChannel#" + LoggingHash.HashString(securityContext) + "::GetRemoteCertificate()");
            }

            X509Certificate2 result = null;
            SafeFreeCertContext remoteContext = null;
            try
            {
                remoteContext = SSPIWrapper.QueryContextAttributes(GlobalSSPI.SSPISecureChannel, securityContext, Interop.SspiCli.ContextAttribute.RemoteCertificate) as SafeFreeCertContext;
                if (remoteContext != null && !remoteContext.IsInvalid)
                {
                    result = new X509Certificate2(remoteContext.DangerousGetHandle());
                }
            }
            finally
            {
                if (remoteContext != null)
                {
                    try
                    {
                        if (!remoteContext.IsInvalid)
                        {
                            remoteCertificateCollection = UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(remoteContext);
                        }
                    }
                    finally
                    {
                        // Release the context deterministically even when reading the collection throws,
                        // and also when the handle turned out to be invalid.
                        remoteContext.Dispose();
                    }
                }
            }

            if (SecurityEventSource.Log.IsEnabled())
            {
                SecurityEventSource.Log.RemoteCertificate(result == null ? "null" : result.ToString(true));
            }

            if (GlobalLog.IsEnabled)
            {
                GlobalLog.Leave("CertificateValidationPal.Windows SecureChannel#" + LoggingHash.HashString(securityContext) + "::GetRemoteCertificate()", (result == null ? "null" : result.Subject));
            }

            return result;
        }

        //
        // Used only by client SSL code, never returns null.
        //
        // Queries the issuer list attribute of the security context and converts every entry
        // into the string form of its X.500 distinguished name.
        //
        // Returns an empty array when the query yields nothing or the list has no entries.
        // An entry whose reported size is zero leaves a null slot at its index in the result.
        //
        internal static string[] GetRequestCertificateAuthorities(SafeDeleteContext securityContext)
        {
            object attribute = SSPIWrapper.QueryContextAttributes(
                GlobalSSPI.SSPISecureChannel,
                securityContext,
                Interop.SspiCli.ContextAttribute.IssuerListInfoEx);

            string[] issuers = Array.Empty<string>();

            // Casting a null query result would throw instead of honouring the "never returns null"
            // contract. NOTE(review): presumably null signals a failed query - confirm against SSPIWrapper.
            if (attribute == null)
            {
                return issuers;
            }

            Interop.SspiCli.IssuerListInfoEx issuerList = (Interop.SspiCli.IssuerListInfoEx)attribute;

            try
            {
                if (issuerList.cIssuers > 0)
                {
                    unsafe
                    {
                        uint count = issuerList.cIssuers;
                        issuers = new string[count];
                        Interop.SspiCli._CERT_CHAIN_ELEMENT* pIL = (Interop.SspiCli._CERT_CHAIN_ELEMENT*)issuerList.aIssuers.DangerousGetHandle();
                        for (int i = 0; i < count; ++i)
                        {
                            Interop.SspiCli._CERT_CHAIN_ELEMENT* pIL2 = pIL + i;
                            if (pIL2->cbSize <= 0)
                            {
                                if (GlobalLog.IsEnabled)
                                {
                                    GlobalLog.Assert("SecureChannel::GetIssuers()", "Interop.SspiCli._CERT_CHAIN_ELEMENT size is not positive: " + pIL2->cbSize.ToString());
                                }

                                Debug.Fail("SecureChannel::GetIssuers()", "Interop.SspiCli._CERT_CHAIN_ELEMENT size is not positive: " + pIL2->cbSize.ToString());

                                // Nothing to decode for this entry; its slot in the result stays null.
                                continue;
                            }

                            // Copy the encoded name out of native memory before handing it to the managed parser.
                            uint size = pIL2->cbSize;
                            byte* ptr = (byte*)(pIL2->pCertContext);
                            byte[] encodedName = new byte[size];
                            Marshal.Copy((IntPtr)ptr, encodedName, 0, encodedName.Length);

                            X500DistinguishedName x500DistinguishedName = new X500DistinguishedName(encodedName);
                            issuers[i] = x500DistinguishedName.Name;
                            if (GlobalLog.IsEnabled)
                            {
                                GlobalLog.Print("SecureChannel#" + LoggingHash.HashString(securityContext) + "::GetIssuers() IssuerListEx[" + i + "]:" + issuers[i]);
                            }
                        }
                    }
                }
            }
            finally
            {
                // The issuer buffer is held by a handle on the query result and must be released here.
                if (issuerList.aIssuers != null)
                {
                    issuerList.aIssuers.Dispose();
                }
            }

            return issuers;
        }

        //
        // Security: We temporarily reset thread token to open the cert store under process account.
        //
        // Returns the cached "My" store for the requested location, opening and caching it on first use.
        // Returns null when opening fails with a CryptographicException or SecurityException; any other
        // exception is logged and rethrown. A failed attempt is not remembered, so the next call retries.
        //
        internal static X509Store EnsureStoreOpened(bool isMachineStore)
        {
            X509Store cachedStore = isMachineStore ? s_myMachineCertStoreEx : s_myCertStoreEx;

            // TODO #3862 Investigate if this can be switched to either the static or Lazy<T> patterns.
            if (cachedStore != null)
            {
                return cachedStore;
            }

            lock (s_syncObject)
            {
                // Re-read under the lock: another thread may have published the store in the meantime.
                cachedStore = isMachineStore ? s_myMachineCertStoreEx : s_myCertStoreEx;
                if (cachedStore != null)
                {
                    return cachedStore;
                }

                // NOTE: that if this call fails we won't keep track and the next time we enter we will try to open the store again.
                StoreLocation location = isMachineStore ? StoreLocation.LocalMachine : StoreLocation.CurrentUser;
                X509Store openedStore = new X509Store(StoreName.My, location);
                try
                {
                    // For app-compat We want to ensure the store is opened under the **process** account.
                    try
                    {
                        WindowsIdentity.RunImpersonated(SafeAccessTokenHandle.InvalidHandle, () =>
                        {
                            openedStore.Open(OpenFlags.ReadOnly | OpenFlags.OpenExistingOnly);
                            if (GlobalLog.IsEnabled)
                            {
                                GlobalLog.Print("SecureChannel::EnsureStoreOpened() storeLocation:" + location + " returned store:" + openedStore.GetHashCode().ToString("x"));
                            }
                        });
                    }
                    catch
                    {
                        // NOTE(review): catch-and-rethrow kept as-is; its purpose is not evident from this file.
                        throw;
                    }

                    // Publish the opened store so later callers skip the lock.
                    if (isMachineStore)
                    {
                        s_myMachineCertStoreEx = openedStore;
                    }
                    else
                    {
                        s_myCertStoreEx = openedStore;
                    }

                    return openedStore;
                }
                catch (Exception openFailure)
                {
                    bool isStoreAccessFailure = openFailure is CryptographicException || openFailure is SecurityException;
                    if (!isStoreAccessFailure)
                    {
                        if (NetEventSource.Log.IsEnabled())
                        {
                            NetEventSource.PrintError(NetEventSource.ComponentType.Security, SR.Format(SR.net_log_open_store_failed, location, openFailure));
                        }

                        throw;
                    }

                    if (GlobalLog.IsEnabled)
                    {
                        GlobalLog.Assert("SecureChannel::EnsureStoreOpened()", "Failed to open cert store, location:" + location + " exception:" + openFailure);
                    }

                    Debug.Fail("SecureChannel::EnsureStoreOpened()", "Failed to open cert store, location:" + location + " exception:" + openFailure);
                    return null;
                }
            }
        }

        //
        // Evaluates the CERT_CHAIN_POLICY_SSL policy against the given chain and returns the
        // dwError value reported in the policy status.
        //
        //   chainContext - handle of the chain to evaluate.
        //   cpp          - policy parameters passed through to CertVerifyCertificateChainPolicy.
        //
        // NOTE(review): the boolean returned by CertVerifyCertificateChainPolicy is only logged;
        // callers see dwError alone. Confirm this is intended for the case where the call returns false.
        //
        private static uint Verify(SafeX509ChainHandle chainContext, ref Interop.Crypt32.CERT_CHAIN_POLICY_PARA cpp)
        {
            if (GlobalLog.IsEnabled)
            {
                string options = String.Format("0x{0:x}", cpp.dwFlags);
                GlobalLog.Enter("SecureChannel::VerifyChainPolicy", "chainContext=" + chainContext + ", options=" + options);
            }

            var policyStatus = new Interop.Crypt32.CERT_CHAIN_POLICY_STATUS()
            {
                cbSize = (uint)Marshal.SizeOf<Interop.Crypt32.CERT_CHAIN_POLICY_STATUS>()
            };

            IntPtr sslPolicy = (IntPtr)Interop.Crypt32.CertChainPolicy.CERT_CHAIN_POLICY_SSL;
            bool callSucceeded = Interop.Crypt32.CertVerifyCertificateChainPolicy(
                sslPolicy,
                chainContext,
                ref cpp,
                ref policyStatus);

            if (GlobalLog.IsEnabled)
            {
                GlobalLog.Print("SecureChannel::VerifyChainPolicy() CertVerifyCertificateChainPolicy returned: " + callSucceeded);
#if TRACE_VERBOSE
                GlobalLog.Print("SecureChannel::VerifyChainPolicy() error code: " + policyStatus.dwError + String.Format(" [0x{0:x8}", policyStatus.dwError) + " " + Interop.MapSecurityStatus(policyStatus.dwError) + "]");
#endif
                GlobalLog.Leave("SecureChannel::VerifyChainPolicy", policyStatus.dwError.ToString());
            }

            return policyStatus.dwError;
        }
    }
}
