﻿// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using NuGet.Packaging.Core;
using NuGet.ProjectModel;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Microsoft.NET.Build.Tasks
{
    /// <summary>
    /// Answers questions about the libraries of a single lock file target: which of them flow to
    /// runtime, which are used for compilation, and which are top-level project dependencies.
    /// </summary>
    internal class ProjectContext
    {
        private readonly LockFile _lockFile;
        private readonly LockFileTarget _lockFileTarget;

        /// <summary>
        /// Optional set of exact package identities (id and version) to drop from the result of
        /// <see cref="GetRuntimeLibraries" />. Ids are compared case-insensitively and versions must be
        /// equal. When <c>null</c>, no identity-based filtering is applied.
        /// </summary>
        internal HashSet<PackageIdentity> PackagesToBeFiltered { get; set; }

        /// <summary>
        /// A value indicating that this project runs on a shared system-wide framework.
        /// (ex. Microsoft.NETCore.App for .NET Core)
        /// </summary>
        public bool IsFrameworkDependent { get; }

        /// <summary>
        /// A value indicating that this project is portable across operating systems, processor architectures, etc.
        /// </summary>
        /// <remarks>
        /// Returns <c>true</c> for projects running on shared frameworks (<see cref="IsFrameworkDependent" />)
        /// that do not target a specific RID.
        /// </remarks>
        public bool IsPortable => IsFrameworkDependent && string.IsNullOrEmpty(_lockFileTarget.RuntimeIdentifier);

        /// <summary>
        /// The shared-framework library the project runs on. Expected to be non-null whenever
        /// <see cref="IsFrameworkDependent" /> is <c>true</c> (asserted in debug builds only).
        /// </summary>
        public LockFileTargetLibrary PlatformLibrary { get; }

        /// <summary>The lock file this context was created from.</summary>
        public LockFile LockFile => _lockFile;

        /// <summary>The lock file target whose libraries this context queries.</summary>
        public LockFileTarget LockFileTarget => _lockFileTarget;

        /// <summary>
        /// Creates a context for querying the libraries of <paramref name="lockFileTarget" />.
        /// </summary>
        /// <param name="lockFile">The lock file; must not be <c>null</c> (checked in debug builds only).</param>
        /// <param name="lockFileTarget">The target to query; must not be <c>null</c> (checked in debug builds only).</param>
        /// <param name="platformLibrary">
        /// The shared-framework library; required when <paramref name="isFrameworkDependent" /> is <c>true</c>.
        /// </param>
        /// <param name="isFrameworkDependent">Whether the project runs on a shared framework.</param>
        public ProjectContext(LockFile lockFile, LockFileTarget lockFileTarget, LockFileTargetLibrary platformLibrary, bool isFrameworkDependent)
        {
            Debug.Assert(lockFile != null);
            Debug.Assert(lockFileTarget != null);
            Debug.Assert(!isFrameworkDependent || platformLibrary != null);

            _lockFile = lockFile;
            _lockFileTarget = lockFileTarget;

            PlatformLibrary = platformLibrary;
            IsFrameworkDependent = isFrameworkDependent;
        }

        /// <summary>
        /// Returns the target's libraries after passing them through the <c>Filter</c> extension with a
        /// combined, case-insensitive exclusion set built from:
        /// the platform exclusion list (only when <see cref="IsFrameworkDependent" />);
        /// the result of <see cref="GetExcludeFromPublishList" /> for <paramref name="excludeFromPublishPackageIds" />;
        /// and every library whose name and version match an entry of <see cref="PackagesToBeFiltered" />.
        /// </summary>
        /// <param name="excludeFromPublishPackageIds">
        /// Package ids to exclude from publish; may be <c>null</c> or empty, in which case this step is skipped.
        /// </param>
        /// <returns>A materialized array of the remaining libraries.</returns>
        public IEnumerable<LockFileTargetLibrary> GetRuntimeLibraries(IEnumerable<string> excludeFromPublishPackageIds)
        {
            IEnumerable<LockFileTargetLibrary> runtimeLibraries = _lockFileTarget.Libraries;
            Dictionary<string, LockFileTargetLibrary> libraryLookup = CreateLibraryLookup(runtimeLibraries);

            var allExclusionList = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            if (IsFrameworkDependent)
            {
                allExclusionList.UnionWith(_lockFileTarget.GetPlatformExclusionList(PlatformLibrary, libraryLookup));
            }

            if (excludeFromPublishPackageIds?.Any() == true)
            {
                allExclusionList.UnionWith(GetExcludeFromPublishList(excludeFromPublishPackageIds, libraryLookup));
            }

            HashSet<PackageIdentity> packagesToBeFiltered = PackagesToBeFiltered;
            if (packagesToBeFiltered != null)
            {
                // Group identities by id so each library name can be checked against every filtered version of it.
                Dictionary<string, HashSet<PackageIdentity>> filterLookup = packagesToBeFiltered
                    .GroupBy(package => package.Id, StringComparer.OrdinalIgnoreCase)
                    .ToDictionary(
                        group => group.Key,
                        group => new HashSet<PackageIdentity>(group),
                        StringComparer.OrdinalIgnoreCase);

                allExclusionList.UnionWith(GetPackagesToBeFiltered(filterLookup, libraryLookup));
            }

            return runtimeLibraries.Filter(allExclusionList).ToArray();
        }

        /// <summary>
        /// Returns the transitive package list for <paramref name="package" />, as computed by the
        /// <c>GetTransitivePackagesList</c> extension over this target's libraries.
        /// </summary>
        /// <param name="package">Name of the library to start from.</param>
        internal IEnumerable<PackageIdentity> GetTransitiveList(string package)
        {
            // NOTE(review): GetLibrary presumably returns null for a name that is not in the target;
            // whether GetTransitivePackagesList tolerates a null library is not visible here — confirm.
            LockFileTargetLibrary packageLibrary = _lockFileTarget.GetLibrary(package);
            Dictionary<string, LockFileTargetLibrary> libraryLookup = CreateLibraryLookup(_lockFileTarget.Libraries);

            return _lockFileTarget.GetTransitivePackagesList(packageLibrary, libraryLookup);
        }

        /// <summary>
        /// Returns the target's libraries used for compilation, minus the result of
        /// <see cref="GetExcludeFromPublishList" /> for <paramref name="compileExcludeFromPublishPackageIds" />
        /// (applied via the <c>Filter</c> extension).
        /// </summary>
        /// <param name="compileExcludeFromPublishPackageIds">
        /// Package ids to exclude; may be <c>null</c> or empty, in which case all libraries are returned.
        /// </param>
        /// <returns>A materialized array of the libraries.</returns>
        public IEnumerable<LockFileTargetLibrary> GetCompileLibraries(IEnumerable<string> compileExcludeFromPublishPackageIds)
        {
            IEnumerable<LockFileTargetLibrary> compileLibraries = _lockFileTarget.Libraries;

            if (compileExcludeFromPublishPackageIds?.Any() == true)
            {
                HashSet<string> excludeFromPublishList = GetExcludeFromPublishList(
                    compileExcludeFromPublishPackageIds,
                    CreateLibraryLookup(compileLibraries));

                compileLibraries = compileLibraries.Filter(excludeFromPublishList);
            }

            return compileLibraries.ToArray();
        }

        /// <summary>
        /// Returns the names of the project's direct dependencies that also appear as libraries in this target.
        /// </summary>
        /// <remarks>
        /// Dependencies are read from the lock file's project dependency groups whose framework name is
        /// empty (framework-agnostic) or equals the target framework's <c>DotNetFrameworkName</c>.
        /// Names are returned as written in the dependency string (not normalized), and a name listed in
        /// more than one matching group is returned more than once.
        /// </remarks>
        public IEnumerable<string> GetTopLevelDependencies()
        {
            Dictionary<string, LockFileTargetLibrary> libraryLookup = CreateLibraryLookup(LockFileTarget.Libraries);

            return LockFile
                .ProjectFileDependencyGroups
                .Where(group => group.FrameworkName == string.Empty ||
                                group.FrameworkName == LockFileTarget.TargetFramework.DotNetFrameworkName)
                .SelectMany(group => group.Dependencies)
                .Select(GetDependencyName)
                .Where(name => !string.IsNullOrEmpty(name) && libraryLookup.ContainsKey(name))
                .ToArray();
        }

        /// <summary>
        /// Computes the set of libraries that exist only to satisfy the excluded packages.
        /// </summary>
        /// <remarks>
        /// Top-level dependencies are split into two groups: those named in
        /// <paramref name="excludeFromPublishPackageIds" /> and all others. Everything reachable from the
        /// non-excluded group is marked as kept. The walk then starts from the excluded group and collects every
        /// reachable library that exists in <paramref name="libraryLookup" />. It does not descend into
        /// dependencies marked as kept. An excluded top-level dependency is always collected, even if it is also
        /// reachable from a kept one.
        /// </remarks>
        /// <param name="excludeFromPublishPackageIds">Package ids to exclude; compared case-insensitively. Must not be <c>null</c>.</param>
        /// <param name="libraryLookup">Library name to library map for the target being processed.</param>
        /// <returns>A case-insensitive set of library names to exclude.</returns>
        public HashSet<string> GetExcludeFromPublishList(
            IEnumerable<string> excludeFromPublishPackageIds,
            IDictionary<string, LockFileTargetLibrary> libraryLookup)
        {
            var nonExcludeFromPublishAssets = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
            var nonExcludeFromPublishAssetsToSearch = new Stack<string>();
            var excludeFromPublishAssetsToSearch = new Stack<string>();

            // Start with the top-level dependencies, and put them into "private" or "non-private" buckets
            var excludeFromPublishPackagesLookup = new HashSet<string>(excludeFromPublishPackageIds, StringComparer.OrdinalIgnoreCase);
            foreach (string topLevelDependency in GetTopLevelDependencies())
            {
                if (excludeFromPublishPackagesLookup.Contains(topLevelDependency))
                {
                    excludeFromPublishAssetsToSearch.Push(topLevelDependency);
                }
                else if (nonExcludeFromPublishAssets.Add(topLevelDependency))
                {
                    nonExcludeFromPublishAssetsToSearch.Push(topLevelDependency);
                }
            }

            // Walk all the non-private assets' dependencies and mark them as non-private.
            while (nonExcludeFromPublishAssetsToSearch.Count > 0)
            {
                string libraryName = nonExcludeFromPublishAssetsToSearch.Pop();
                LockFileTargetLibrary library;
                if (!libraryLookup.TryGetValue(libraryName, out library))
                {
                    continue;
                }

                foreach (var dependency in library.Dependencies)
                {
                    if (nonExcludeFromPublishAssets.Add(dependency.Id))
                    {
                        nonExcludeFromPublishAssetsToSearch.Push(dependency.Id);
                    }
                }
            }

            // Go through assets marked private and their dependencies.
            // For libraries not marked as non-private, mark them down as private.
            var assetsToExcludeFromPublish = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
            while (excludeFromPublishAssetsToSearch.Count > 0)
            {
                string libraryName = excludeFromPublishAssetsToSearch.Pop();
                LockFileTargetLibrary library;

                // Add() returns false for a library that has already been expanded. Skipping it keeps the walk
                // from re-expanding shared (diamond) dependencies and guarantees termination if a cycle exists.
                if (!libraryLookup.TryGetValue(libraryName, out library) ||
                    !assetsToExcludeFromPublish.Add(libraryName))
                {
                    continue;
                }

                foreach (var dependency in library.Dependencies)
                {
                    if (!nonExcludeFromPublishAssets.Contains(dependency.Id) &&
                        !assetsToExcludeFromPublish.Contains(dependency.Id))
                    {
                        excludeFromPublishAssetsToSearch.Push(dependency.Id);
                    }
                }
            }

            return assetsToExcludeFromPublish;
        }

        /// <summary>
        /// Returns the names of libraries in <paramref name="packagesToBePublished" /> whose version equals the
        /// version of some identity registered under the same (case-insensitive) name in
        /// <paramref name="packagesToBeFiltered" />.
        /// </summary>
        /// <param name="packagesToBeFiltered">Identities to filter, grouped by package id.</param>
        /// <param name="packagesToBePublished">Library name to library map for the target being processed.</param>
        private static HashSet<string> GetPackagesToBeFiltered(
            IDictionary<string, HashSet<PackageIdentity>> packagesToBeFiltered,
            IDictionary<string, LockFileTargetLibrary> packagesToBePublished)
        {
            var exclusionList = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            foreach (KeyValuePair<string, LockFileTargetLibrary> entry in packagesToBePublished)
            {
                HashSet<PackageIdentity> candidates;
                if (!packagesToBeFiltered.TryGetValue(entry.Key, out candidates))
                {
                    continue;
                }

                LockFileTargetLibrary library = entry.Value;
                if (candidates.Any(candidate => library.Version.Equals(candidate.Version)))
                {
                    exclusionList.Add(entry.Key);
                }
            }

            return exclusionList;
        }

        /// <summary>
        /// Builds a case-insensitive map from library name to library. Like <c>Enumerable.ToDictionary</c>,
        /// this throws <see cref="ArgumentException" /> if two libraries share a name under that comparison.
        /// </summary>
        private static Dictionary<string, LockFileTargetLibrary> CreateLibraryLookup(IEnumerable<LockFileTargetLibrary> libraries)
        {
            return libraries.ToDictionary(library => library.Name, StringComparer.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Extracts the library name from a project file dependency string: the text before the first space
        /// when that space is not the first character, otherwise the whole string. The remainder is
        /// presumably a version constraint.
        /// </summary>
        private static string GetDependencyName(string projectFileDependency)
        {
            int separatorIndex = projectFileDependency.IndexOf(' ');
            return separatorIndex > 0
                ? projectFileDependency.Substring(0, separatorIndex)
                : projectFileDependency;
        }
    }
}
