using DlibDotNet;
using DlibDotNet.Dnn;
using System.Collections.ObjectModel;
using System.Drawing;
using System.Drawing.Imaging;
using System.Runtime.InteropServices;
using View_by_Distance.FaceRecognitionDotNet.Dlib.Python;
using View_by_Distance.FaceRecognitionDotNet.Extensions;
using View_by_Distance.Shared.Models;
using View_by_Distance.Shared.Models.Stateless;

namespace View_by_Distance.FaceRecognitionDotNet.Models;

/// <summary>
/// Wraps the dlib face detectors, landmark (shape) predictors and face-encoding network used to locate faces,
/// extract landmarks and compute face encodings, plus static helpers for loading images/encodings and
/// comparing encodings.
/// </summary>
/// <remarks>
/// The native models created in the constructor are owned by this instance and released in
/// <c>DisposeUnmanaged</c>. <see cref="CustomFaceDetector"/> and <see cref="CustomFaceLandmarkDetector"/>
/// are assigned by the caller and are not disposed by this class.
/// </remarks>
public class FaceRecognition : DisposableObject
{

    /// <summary>Detector used by <c>GetMModRects</c> when the configured model is <c>Model.Custom</c>.</summary>
    public FaceDetector? CustomFaceDetector { get; set; }
    /// <summary>Landmark detector used by <c>GetFullObjectDetections</c> when the predictor model is <c>PredictorModel.Custom</c>.</summary>
    public FaceLandmarkDetector? CustomFaceLandmarkDetector { get; set; }

    private readonly Model _Model;
    private readonly int _NumberOfJitters;
    private readonly LossMetric _FaceEncoder;
    private readonly LossMmod _CnnFaceDetector;
    private readonly int _NumberOfTimesToUpsample;
    private readonly PredictorModel _PredictorModel;
    private readonly FrontalFaceDetector _FaceDetector;
    private readonly ShapePredictor _PosePredictor5Point;
    private readonly ShapePredictor _PosePredictor68Point;

    /// <summary>
    /// Loads every dlib model referenced by <paramref name="modelParameter"/>.
    /// </summary>
    /// <param name="numberOfJitters">Jitter count passed to the face-descriptor computation in <see cref="GetCollection"/>.</param>
    /// <param name="numberOfTimesToUpsample">Upsample count passed to the HOG, CNN and custom face detectors.</param>
    /// <param name="model">Which face detector <c>GetMModRects</c> uses.</param>
    /// <param name="modelParameter">Serialized model data for the 5/68-point predictors, the CNN detector and the encoder.</param>
    /// <param name="predictorModel">Which landmark predictor is used.</param>
    /// <exception cref="NullReferenceException"><paramref name="modelParameter"/> or one of its model entries is null.</exception>
    public FaceRecognition(int numberOfJitters, int numberOfTimesToUpsample, Model model, ModelParameter modelParameter, PredictorModel predictorModel)
    {
        if (modelParameter is null)
            throw new NullReferenceException(nameof(modelParameter));
        if (modelParameter.PosePredictor5FaceLandmarksModel is null)
            throw new NullReferenceException(nameof(modelParameter.PosePredictor5FaceLandmarksModel));
        if (modelParameter.PosePredictor68FaceLandmarksModel is null)
            throw new NullReferenceException(nameof(modelParameter.PosePredictor68FaceLandmarksModel));
        if (modelParameter.CnnFaceDetectorModel is null)
            throw new NullReferenceException(nameof(modelParameter.CnnFaceDetectorModel));
        if (modelParameter.FaceRecognitionModel is null)
            throw new NullReferenceException(nameof(modelParameter.FaceRecognitionModel));
        _Model = model;
        _PredictorModel = predictorModel;
        _NumberOfJitters = numberOfJitters;
        _NumberOfTimesToUpsample = numberOfTimesToUpsample;
        try
        {
            _FaceDetector = DlibDotNet.Dlib.GetFrontalFaceDetector();
            _PosePredictor68Point = ShapePredictor.Deserialize(modelParameter.PosePredictor68FaceLandmarksModel);
            _PosePredictor5Point = ShapePredictor.Deserialize(modelParameter.PosePredictor5FaceLandmarksModel);
            _CnnFaceDetector = LossMmod.Deserialize(modelParameter.CnnFaceDetectorModel);
            _FaceEncoder = LossMetric.Deserialize(modelParameter.FaceRecognitionModel);
        }
        catch
        {
            // A later model failed to load: release the native models already created before propagating,
            // since the caller never receives this instance and cannot dispose it.
            _FaceEncoder?.Dispose();
            _CnnFaceDetector?.Dispose();
            _PosePredictor5Point?.Dispose();
            _PosePredictor68Point?.Dispose();
            _FaceDetector?.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Returns the dlib length (Euclidean norm) of <c>faceEncoding - faceToCompare</c>.
    /// </summary>
    /// <returns>The distance, or 0 when <paramref name="faceEncoding"/> has no elements.</returns>
    /// <exception cref="NullReferenceException">Either argument is null.</exception>
    public static double FaceDistance(FaceEncoding faceEncoding, FaceEncoding faceToCompare)
    {
        if (faceEncoding is null)
            throw new NullReferenceException(nameof(faceEncoding));
        if (faceToCompare is null)
            throw new NullReferenceException(nameof(faceToCompare));
        faceEncoding.ThrowIfDisposed();
        faceToCompare.ThrowIfDisposed();
        if (faceEncoding.Encoding.Size == 0)
            return 0;
        return GetDistance(faceEncoding, faceToCompare);
    }

    /// <summary>Computes <c>Dlib.Length(faceEncoding.Encoding - faceToCompare.Encoding)</c>, disposing the temporary difference matrix.</summary>
    private static double GetDistance(FaceEncoding faceEncoding, FaceEncoding faceToCompare)
    {
        using Matrix<double> diff = faceEncoding.Encoding - faceToCompare.Encoding;
        return DlibDotNet.Dlib.Length(diff);
    }

    /// <summary>Concatenates two point sequences into a new array (first, then second).</summary>
    private static FacePoint[] Join(IEnumerable<FacePoint> facePoints1, IEnumerable<FacePoint> facePoints2) =>
        [.. facePoints1, .. facePoints2];

    /// <summary>
    /// Groups the landmark points of <paramref name="fullObjectDetection"/> into named face parts for the
    /// configured predictor model.
    /// </summary>
    /// <returns>
    /// The face parts, or an empty list when the number of detected points does not match the model
    /// (68 for <c>Large</c>, 5 for <c>Small</c>).
    /// </returns>
    /// <exception cref="NotImplementedException">The predictor model is <c>Custom</c>.</exception>
    private List<FacePartAndFacePointArray> GetFaceParts(FullObjectDetection fullObjectDetection)
    {
        List<FacePartAndFacePointArray> results = [];
        FacePoint[] facePoints = new FacePoint[(int)fullObjectDetection.Parts];
        for (int index = 0; index < facePoints.Length; index++)
        {
            // Read each native part once rather than once per coordinate.
            var point = fullObjectDetection.GetPart((uint)index);
            facePoints[index] = new FacePoint(index, point.X, point.Y);
        }
        switch (_PredictorModel)
        {
            case PredictorModel.Custom:
                throw new NotImplementedException();
            case PredictorModel.Large:
                if (facePoints.Length == 68)
                {
                    results.Add(new FacePartAndFacePointArray(FacePart.Chin, facePoints[0..17]));
                    results.Add(new FacePartAndFacePointArray(FacePart.LeftEyebrow, facePoints[17..22]));
                    results.Add(new FacePartAndFacePointArray(FacePart.RightEyebrow, facePoints[22..27]));
                    // NOTE(review): NoseBridge spans points 27-31, so point 31 is shared with NoseTip; the Python
                    // face_recognition library uses 27-30. Left unchanged because stored face parts depend on it.
                    results.Add(new FacePartAndFacePointArray(FacePart.NoseBridge, facePoints[27..32]));
                    results.Add(new FacePartAndFacePointArray(FacePart.NoseTip, facePoints[31..36]));
                    results.Add(new FacePartAndFacePointArray(FacePart.LeftEye, facePoints[36..42]));
                    results.Add(new FacePartAndFacePointArray(FacePart.RightEye, facePoints[42..48]));
                    // Lips: outer contour 48-59 and inner contour 60-67 are split between top and bottom.
                    results.Add(new FacePartAndFacePointArray(FacePart.TopLip, Join(facePoints[48..55], facePoints[60..65])));
                    results.Add(new FacePartAndFacePointArray(FacePart.BottomLip, Join(facePoints[55..60], facePoints[65..68])));
                }
                break;
            case PredictorModel.Small:
                if (facePoints.Length == 5)
                {
                    results.Add(new FacePartAndFacePointArray(FacePart.RightEye, facePoints[0..2]));
                    results.Add(new FacePartAndFacePointArray(FacePart.LeftEye, facePoints[2..4]));
                    results.Add(new FacePartAndFacePointArray(FacePart.NoseTip, facePoints[4..5]));
                }
                break;
            default:
                break;
        }
        return results;
    }

    /// <summary>
    /// Runs the configured face detector on <paramref name="image"/>.
    /// </summary>
    /// <returns>One rectangle per detection; the caller owns (and must dispose) the returned objects.</returns>
    /// <exception cref="NotSupportedException">
    /// The model is <c>Custom</c> and <see cref="CustomFaceDetector"/> is not set, or the model value is unknown.
    /// </exception>
    private MModRect[] GetMModRects(Image image)
    {
        switch (_Model)
        {
            case Model.Cnn:
                return CnnFaceDetectionModelV1.Detect(_CnnFaceDetector, image, _NumberOfTimesToUpsample).ToArray();
            case Model.Hog:
                IEnumerable<Tuple<DlibDotNet.Rectangle, double>>? locations = SimpleObjectDetector.RunDetectorWithUpscale2(_FaceDetector, image, (uint)_NumberOfTimesToUpsample);
                return locations.Select(l => new MModRect { Rect = l.Item1, DetectionConfidence = l.Item2 }).ToArray();
            case Model.Custom:
                if (CustomFaceDetector is null)
                    throw new NotSupportedException("The custom face detector is not ready.");
                return CustomFaceDetector.Detect(image, _NumberOfTimesToUpsample).Select(rect => new MModRect
                {
                    Rect = new DlibDotNet.Rectangle(rect.Left, rect.Top, rect.Right, rect.Bottom),
                    DetectionConfidence = rect.Confidence
                }).ToArray();
            default:
                throw new NotSupportedException($"Face detection model {_Model} is not supported.");
        }
    }

    /// <summary>
    /// Detects faces in <paramref name="image"/> and returns their bounds, trimmed to the image via
    /// <c>ILocation.TrimBound</c>.
    /// </summary>
    /// <exception cref="NullReferenceException"><paramref name="image"/> is null.</exception>
    public List<Location> FaceLocations(Image image)
    {
        if (image is null)
            throw new NullReferenceException(nameof(image));
        image.ThrowIfDisposed();
        ThrowIfDisposed();
        return GetLocations(image);
    }

    /// <summary>
    /// Produces one landmark detection per location, using either <see cref="CustomFaceLandmarkDetector"/>
    /// or the 68/5-point shape predictor selected by the predictor model.
    /// </summary>
    /// <returns>Detections in the same order as <paramref name="locations"/>; the caller must dispose them.</returns>
    private List<FullObjectDetection> GetFullObjectDetections(Image image, List<Location> locations)
    {
        List<FullObjectDetection> results = new(locations.Count);
        try
        {
            if (_PredictorModel == PredictorModel.Custom)
            {
                if (CustomFaceLandmarkDetector is null)
                    throw new NullReferenceException(nameof(CustomFaceLandmarkDetector));
                foreach (Location location in locations)
                    results.Add(CustomFaceLandmarkDetector.Detect(image, location));
            }
            else
            {
                ShapePredictor posePredictor = _PredictorModel switch
                {
                    PredictorModel.Large => _PosePredictor68Point,
                    PredictorModel.Small => _PosePredictor5Point,
                    _ => throw new NotSupportedException($"Predictor model {_PredictorModel} is not supported.")
                };
                foreach (Location location in locations)
                {
                    DlibDotNet.Rectangle rectangle = new(location.Left, location.Top, location.Right, location.Bottom);
                    results.Add(posePredictor.Detect(image.Matrix, rectangle));
                }
            }
        }
        catch
        {
            // Do not leak the native detections created before the failure.
            foreach (FullObjectDetection fullObjectDetection in results)
                fullObjectDetection.Dispose();
            throw;
        }
        return results;
    }

    /// <summary>
    /// Runs face detection and converts every detected rectangle to a bounds-trimmed <c>Location</c>.
    /// The intermediate native rectangles are always disposed, even if conversion throws.
    /// </summary>
    private List<Location> GetLocations(Image image)
    {
        MModRect[] mModRects = GetMModRects(image);
        List<Location> results = new(mModRects.Length);
        try
        {
            foreach (MModRect mModRect in mModRects)
            {
                System.Drawing.Rectangle rectangle = new(mModRect.Rect.Left, mModRect.Rect.Top, (int)mModRect.Rect.Width, (int)mModRect.Rect.Height);
                Location location = ILocation.TrimBound(mModRect.DetectionConfidence, rectangle, image.Width, image.Height, mModRects.Length);
                results.Add(location);
            }
        }
        finally
        {
            foreach (MModRect mModRect in mModRects)
                mModRect.Dispose();
        }
        return results;
    }

    /// <summary>
    /// Builds one <c>FaceRecognitionGroup</c> per face location, optionally including the face encoding and
    /// the named landmark parts.
    /// </summary>
    /// <param name="image">Image to analyse.</param>
    /// <param name="locations">
    /// Face locations to use. When empty, faces are detected and the detected locations are appended to this
    /// list (the caller's list is mutated).
    /// </param>
    /// <param name="includeFaceEncoding">When false, every group's encoding is null.</param>
    /// <param name="includeFaceParts">When false (or no parts match the model), every group's parts are null.</param>
    /// <exception cref="NullReferenceException"><paramref name="image"/> or <paramref name="locations"/> is null.</exception>
    /// <exception cref="NotSupportedException">The predictor model is <c>Custom</c>.</exception>
    public List<FaceRecognitionGroup> GetCollection(Image image, List<Location> locations, bool includeFaceEncoding, bool includeFaceParts)
    {
        if (image is null)
            throw new NullReferenceException(nameof(image));
        image.ThrowIfDisposed();
        ThrowIfDisposed();
        if (_PredictorModel == PredictorModel.Custom)
            throw new NotSupportedException("FaceRecognition.PredictorModel.Custom is not supported.");
        if (locations is null)
            throw new NullReferenceException(nameof(locations));
        if (locations.Count == 0)
            locations.AddRange(GetLocations(image));
        List<FaceRecognitionGroup> results = new(locations.Count);
        List<FullObjectDetection> fullObjectDetections = GetFullObjectDetections(image, locations);
        try
        {
            if (fullObjectDetections.Count != locations.Count)
                throw new InvalidOperationException("Expected exactly one landmark detection per location.");
            for (int i = 0; i < locations.Count; i++)
            {
                FaceEncoding? faceEncoding = null;
                if (includeFaceEncoding)
                {
                    Matrix<double> doubles = FaceRecognitionModelV1.ComputeFaceDescriptor(_FaceEncoder, image, fullObjectDetections[i], _NumberOfJitters);
                    faceEncoding = new FaceEncoding(doubles);
                }
                Dictionary<FacePart, FacePoint[]>? faceParts = null;
                if (includeFaceParts)
                {
                    List<FacePartAndFacePointArray> facePartAndFacePointArrays = GetFaceParts(fullObjectDetections[i]);
                    // An empty part list (point count did not match the model) is reported as null.
                    if (facePartAndFacePointArrays.Count != 0)
                    {
                        faceParts = new Dictionary<FacePart, FacePoint[]>();
                        foreach (FacePartAndFacePointArray facePartAndFacePointArray in facePartAndFacePointArrays)
                            faceParts.Add(facePartAndFacePointArray.FacePart, facePartAndFacePointArray.FacePoints);
                    }
                }
                results.Add(new FaceRecognitionGroup(locations[i], faceEncoding, faceParts));
            }
        }
        finally
        {
            // The native landmark detections are only needed while building the groups.
            foreach (FullObjectDetection fullObjectDetection in fullObjectDetections)
                fullObjectDetection.Dispose();
        }
        return results;
    }

    /// <summary>Wraps a 128-element encoding array in a <c>FaceEncoding</c>.</summary>
    /// <exception cref="NullReferenceException"><paramref name="encoding"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="encoding"/> does not have 128 elements.</exception>
    public static FaceEncoding LoadFaceEncoding(double[] encoding) =>
        CreateFaceEncoding(encoding, 128);

    /// <summary>Wraps a 512-element encoding array in a <c>FaceEncoding</c>.</summary>
    /// <exception cref="NullReferenceException"><paramref name="encoding"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="encoding"/> does not have 512 elements.</exception>
    public static FaceEncoding LoadBFaceEncoding(double[] encoding) =>
        CreateFaceEncoding(encoding, 512);

    /// <summary>Copies <paramref name="encoding"/> into a new column matrix of <paramref name="expectedLength"/> elements.</summary>
    private static FaceEncoding CreateFaceEncoding(double[] encoding, int expectedLength)
    {
        if (encoding is null)
            throw new NullReferenceException(nameof(encoding));
        if (encoding.Length != expectedLength)
        {
            string message = $"{nameof(encoding)}.{nameof(encoding.Length)} must be {expectedLength}.";
            // The (paramName, message) overload is required; the single-string overload treats its argument as paramName.
            throw new ArgumentOutOfRangeException(nameof(encoding), message);
        }
#pragma warning disable
        Matrix<double>? matrix = Matrix<double>.CreateTemplateParameterizeMatrix(0, 1);
#pragma warning restore
        try
        {
            matrix.SetSize(expectedLength);
            matrix.Assign(encoding);
        }
        catch
        {
            matrix.Dispose();
            throw;
        }
        return new FaceEncoding(matrix);
    }

    /// <summary>Loads an image file from disk as an RGB or greyscale dlib image.</summary>
    /// <exception cref="FileNotFoundException"><paramref name="file"/> does not exist.</exception>
    /// <exception cref="NotImplementedException"><paramref name="mode"/> is neither RGB nor greyscale.</exception>
    public static Image LoadImageFile(string file, Mode mode = Mode.Rgb)
    {
        if (!File.Exists(file))
            throw new FileNotFoundException(file);
        return mode switch
        {
            Mode.Rgb => new Image(DlibDotNet.Dlib.LoadImageAsMatrix<RgbPixel>(file), mode),
            Mode.Greyscale => new Image(DlibDotNet.Dlib.LoadImageAsMatrix<byte>(file), mode),
            _ => throw new NotImplementedException()
        };
    }

#pragma warning disable CA1416

    /// <summary>
    /// Copies the pixels of a GDI+ <see cref="Bitmap"/> into a new dlib-backed image.
    /// </summary>
    /// <remarks>
    /// 8bpp indexed bitmaps become greyscale images whose intensities are the raw palette indices
    /// (NOTE(review): this is only correct for greyscale palettes). 24bpp and 32bpp bitmaps are converted
    /// from BGR(A) byte order to RGB, and any fourth channel is dropped.
    /// </remarks>
    /// <exception cref="NullReferenceException"><paramref name="bitmap"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The bitmap's pixel format is not supported.</exception>
    public static Image? LoadImage(Bitmap bitmap)
    {
        if (bitmap is null)
            throw new NullReferenceException(nameof(bitmap));
        Mode mode;
        int dstChannel;
        int srcChannel;
        int width = bitmap.Width;
        int height = bitmap.Height;
        PixelFormat format = bitmap.PixelFormat;
        System.Drawing.Rectangle rect = new(0, 0, width, height);
#pragma warning disable IDE0010
        switch (format)
        {
            case PixelFormat.Format8bppIndexed:
                mode = Mode.Greyscale;
                srcChannel = 1;
                dstChannel = 1;
                break;
            case PixelFormat.Format24bppRgb:
                mode = Mode.Rgb;
                srcChannel = 3;
                dstChannel = 3;
                break;
            case PixelFormat.Format32bppRgb:
            case PixelFormat.Format32bppArgb:
                mode = Mode.Rgb;
                srcChannel = 4;
                dstChannel = 3;
                break;
            default:
                throw new ArgumentOutOfRangeException($"{nameof(bitmap)}", $"The specified {nameof(PixelFormat)} is not supported.");
        }
#pragma warning restore IDE0010
        BitmapData? data = null;
        try
        {
            data = bitmap.LockBits(rect, ImageLockMode.ReadOnly, format);
            unsafe
            {
                // Tightly packed destination buffer (no row padding), copied into the dlib matrix below.
                byte[]? array = new byte[width * height * dstChannel];
                fixed (byte* pArray = &array[0])
                {
                    byte* dst = pArray;

                    switch (srcChannel)
                    {
                        case 1:
                            {
                                IntPtr src = data.Scan0;
                                int stride = data.Stride;

                                // Copy row by row because the source rows may be padded to Stride bytes.
                                for (int h = 0; h < height; h++)
                                    Marshal.Copy(IntPtr.Add(src, h * stride), array, h * width, width * dstChannel);
                            }
                            break;
                        case 3:
                        case 4:
                            {
                                byte* src = (byte*)data.Scan0;
                                int stride = data.Stride;
                                for (int h = 0; h < height; h++)
                                {
                                    int srcOffset = h * stride;
                                    int dstOffset = h * width * dstChannel;
                                    for (int w = 0; w < width; w++)
                                    {
                                        // BGR order to RGB order
                                        dst[dstOffset + (w * dstChannel) + 0] = src[srcOffset + (w * srcChannel) + 2];
                                        dst[dstOffset + (w * dstChannel) + 1] = src[srcOffset + (w * srcChannel) + 1];
                                        dst[dstOffset + (w * dstChannel) + 2] = src[srcOffset + (w * srcChannel) + 0];
                                    }
                                }
                            }
                            break;
                        default:
                            break;
                    }
                    IntPtr ptr = (IntPtr)pArray;
                    switch (mode)
                    {
                        case Mode.Rgb:
                            return new Image(new Matrix<RgbPixel>(ptr, height, width, width * 3), Mode.Rgb);
                        case Mode.Greyscale:
                            return new Image(new Matrix<byte>(ptr, height, width, width), Mode.Greyscale);
                        default:
                            break;
                    }
                }
            }
        }
        finally
        {
            if (data != null)
                bitmap.UnlockBits(data);
        }
        return null;
    }

    /// <summary>
    /// Compares every container's encoding with <paramref name="locationContainer"/>'s encoding and returns new
    /// containers carrying the distance, sorted ascending by <c>LengthPermyriad</c>.
    /// </summary>
    /// <param name="permyriad">Scale factor applied to each distance before truncating it to an int.</param>
    /// <exception cref="NullReferenceException"><paramref name="locationContainer"/> has no encoding (only checked when the collection is non-empty).</exception>
    /// <exception cref="ObjectDisposedException">An element has no encoding or a disposed one.</exception>
    public static ReadOnlyCollection<LocationContainer> GetLocationContainers(int permyriad, ReadOnlyCollection<LocationContainer> locationContainers, LocationContainer locationContainer)
    {
        List<LocationContainer> results = new(locationContainers.Count);
        if (locationContainers.Count != 0)
        {
            if (locationContainer.Encoding is not FaceEncoding faceEncodingToCompare)
                throw new NullReferenceException(nameof(locationContainer));
            faceEncodingToCompare.ThrowIfDisposed();
            foreach (LocationContainer l in locationContainers)
            {
#pragma warning disable CA1513
                if (l.Encoding is not FaceEncoding faceEncoding || faceEncoding.IsDisposed)
                    throw new ObjectDisposedException($"{nameof(locationContainers)} contains disposed object.");
#pragma warning restore CA1513
                double length = GetDistance(faceEncoding, faceEncodingToCompare);
                int lengthPermyriad = (int)(length * permyriad);
                results.Add(LocationContainer.Get(locationContainer, l, lengthPermyriad, keepExifDirectory: false, keepEncoding: false));
            }
        }
        LocationContainer[] array = results.OrderBy(l => l.LengthPermyriad).ToArray();
        return array.AsReadOnly();
    }

    /// <summary>
    /// Computes the distance from every element's encoding to <paramref name="faceDistanceToCompare"/>'s encoding,
    /// preserving input order.
    /// </summary>
    /// <exception cref="NullReferenceException"><paramref name="faceDistanceToCompare"/> has no encoding (only checked when the collection is non-empty).</exception>
    /// <exception cref="ObjectDisposedException">An element has no encoding or a disposed one.</exception>
    public static List<FaceDistance> FaceDistances(ReadOnlyCollection<FaceDistance> faceDistances, FaceDistance faceDistanceToCompare)
    {
        List<FaceDistance> results = new(faceDistances.Count);
        if (faceDistances.Count != 0)
        {
            if (faceDistanceToCompare.Encoding is not FaceEncoding faceEncodingToCompare)
                throw new NullReferenceException(nameof(faceDistanceToCompare));
            faceEncodingToCompare.ThrowIfDisposed();
            foreach (FaceDistance faceDistance in faceDistances)
            {
#pragma warning disable CA1513
                if (faceDistance.Encoding is not FaceEncoding faceEncoding || faceEncoding.IsDisposed)
                    throw new ObjectDisposedException($"{nameof(faceDistances)} contains disposed object.");
#pragma warning restore CA1513
                double length = GetDistance(faceEncoding, faceEncodingToCompare);
                FaceDistance result = new(faceDistance, length);
                results.Add(result);
            }
        }
        return results;
    }

#pragma warning restore CA1416

    /// <summary>Releases the native models loaded in the constructor.</summary>
    protected override void DisposeUnmanaged()
    {
        base.DisposeUnmanaged();
        _PosePredictor68Point?.Dispose();
        _PosePredictor5Point?.Dispose();
        _CnnFaceDetector?.Dispose();
        _FaceEncoder?.Dispose();
        _FaceDetector?.Dispose();
    }

}