using DlibDotNet;
using DlibDotNet.Dnn;

namespace View_by_Distance.FaceRecognitionDotNet.Dlib.Python;

internal sealed class FaceRecognitionModelV1
{

    /// <summary>Chip size passed to <c>GetFaceChipDetails</c> when aligning each face.</summary>
    private const int FaceChipSize = 150;

    /// <summary>Padding passed to <c>GetFaceChipDetails</c> when aligning each face.</summary>
    private const double FaceChipPadding = 0.25;

    /// <summary>Mini-batch size passed to <c>LossMetric.Operator</c>.</summary>
    private const int BatchSize = 16;

    #region Methods

    /// <summary>
    /// Computes the descriptor of a single face in a single image.
    /// </summary>
    /// <param name="net">Metric network used to embed the aligned face chip.</param>
    /// <param name="img">Image containing the face.</param>
    /// <param name="face">Landmark detection for the face; must have 68 or 5 parts.</param>
    /// <param name="numberOfJitters">
    /// When greater than 1, the chip is randomly jittered this many times and the descriptor
    /// is the mean of the network outputs over the jittered copies.
    /// </param>
    /// <returns>The face descriptor as a matrix of doubles.</returns>
    public static Matrix<double> ComputeFaceDescriptor(LossMetric net, Image img, FullObjectDetection face, int numberOfJitters)
    {
        FullObjectDetection[] faces = new[] { face };
        return ComputeFaceDescriptors(net, img, faces, numberOfJitters).First();
    }

    /// <summary>
    /// Computes one descriptor per face for a single image.
    /// </summary>
    /// <param name="net">Metric network used to embed the aligned face chips.</param>
    /// <param name="img">Image containing the faces.</param>
    /// <param name="faces">Landmark detections; each must have 68 or 5 parts.</param>
    /// <param name="numberOfJitters">See <see cref="ComputeFaceDescriptor"/>.</param>
    /// <returns>One descriptor per face, in the order of <paramref name="faces"/>.</returns>
    public static IEnumerable<Matrix<double>> ComputeFaceDescriptors(LossMetric net, Image img, IEnumerable<FullObjectDetection> faces, int numberOfJitters)
    {
        Image[] batchImage = new[] { img };
        IEnumerable<FullObjectDetection>[] batchFaces = new[] { faces };
        return BatchComputeFaceDescriptors(net, batchImage, batchFaces, numberOfJitters).First();
    }

    /// <summary>
    /// Computes face descriptors for a batch of images, each with its own set of face detections.
    /// </summary>
    /// <param name="net">Metric network used to embed the aligned face chips.</param>
    /// <param name="batchImages">Images to process.</param>
    /// <param name="batchFaces">
    /// Face detections per image; <c>batchFaces[i]</c> belongs to <c>batchImages[i]</c>.
    /// Each detection must have 68 or 5 parts.
    /// </param>
    /// <param name="numberOfJitters">
    /// When greater than 1, each chip is randomly jittered this many times and its descriptor
    /// is the mean of the network outputs over the jittered copies.
    /// </param>
    /// <returns>One sequence per image, each holding one descriptor per face, in input order.</returns>
    /// <exception cref="ArgumentNullException">A required argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// The two lists differ in length, or a detection has neither 68 nor 5 parts.
    /// </exception>
    /// <exception cref="ApplicationException">The number of descriptors produced does not match the number of faces.</exception>
    public static IEnumerable<IEnumerable<Matrix<double>>> BatchComputeFaceDescriptors(LossMetric net,
                                                                                       IList<Image> batchImages,
                                                                                       IList<IEnumerable<FullObjectDetection>> batchFaces,
                                                                                       int numberOfJitters)
    {
        if (net is null)
            throw new ArgumentNullException(nameof(net));
        if (batchImages is null)
            throw new ArgumentNullException(nameof(batchImages));
        if (batchFaces is null)
            throw new ArgumentNullException(nameof(batchFaces));
        if (batchImages.Count != batchFaces.Count)
            throw new ArgumentException("The array of images and the array of array of locations must be of the same size");

        // Materialize each image's detections once so that lazily-evaluated sequences are not
        // re-enumerated for validation, chip extraction and every inner-loop bound check.
        List<FullObjectDetection[]> facesPerImage = batchFaces.Select(faces => faces.ToArray()).ToList();
        ValidateLandmarkLayouts(facesPerImage);

        List<List<Matrix<double>>> faceDescriptors = new(batchImages.Count);
        for (int i = 0; i < batchImages.Count; i++)
            faceDescriptors.Add(new List<Matrix<double>>());

        // Every extracted chip array is registered here as soon as it exists so the finally
        // block can release it even when extraction or the network call throws.
        List<Array<Matrix<RgbPixel>>> faceChipsArray = new(batchImages.Count);
        try
        {
            List<Matrix<RgbPixel>> faceChips = ExtractFaceChips(batchImages, facesPerImage, faceChipsArray);

            if (numberOfJitters <= 1)
                AddDescriptors(net, faceChips, facesPerImage, faceDescriptors);
            else
                AddJitteredDescriptors(net, faceChips, facesPerImage, faceDescriptors, numberOfJitters);
        }
        finally
        {
            DisposeFaceChips(faceChipsArray);
        }

        return faceDescriptors;
    }

    #region Helpers

    // NOTE(review): shared by every call; confirm DlibDotNet.Rand is safe for concurrent use
    // if descriptors with jitter can be computed from several threads at once.
    private static readonly Rand _Rand = new();

    /// <summary>Throws if any detection has neither 68 nor 5 landmark parts.</summary>
    private static void ValidateLandmarkLayouts(IEnumerable<FullObjectDetection[]> facesPerImage)
    {
        foreach (FullObjectDetection[] faces in facesPerImage)
        {
            foreach (FullObjectDetection f in faces)
            {
                if (f.Parts is not 68 and not 5)
                    throw new ArgumentException("The full_object_detection must use the iBUG 300W 68 point face landmark style or dlib's 5 point style.");
            }
        }
    }

    /// <summary>
    /// Extracts aligned face chips for every face of every image, appending each per-image chip
    /// array to <paramref name="faceChipsArray"/> (for later disposal) and returning all chips
    /// flattened in image-then-face order.
    /// </summary>
    private static List<Matrix<RgbPixel>> ExtractFaceChips(IList<Image> batchImages,
                                                           IReadOnlyList<FullObjectDetection[]> facesPerImage,
                                                           List<Array<Matrix<RgbPixel>>> faceChipsArray)
    {
        List<Matrix<RgbPixel>> faceChips = new();
        for (int i = 0; i < batchImages.Count; ++i)
        {
            FullObjectDetection[] faces = facesPerImage[i];
            List<ChipDetails> dets = new(faces.Length);
            try
            {
                foreach (FullObjectDetection f in faces)
                    dets.Add(DlibDotNet.Dlib.GetFaceChipDetails(f, FaceChipSize, FaceChipPadding));

                Array<Matrix<RgbPixel>> thisImageFaceChips = DlibDotNet.Dlib.ExtractImageChips<RgbPixel>(batchImages[i].Matrix, dets);
                faceChipsArray.Add(thisImageFaceChips);
                foreach (Matrix<RgbPixel> chip in thisImageFaceChips)
                    faceChips.Add(chip);
            }
            finally
            {
                // The chip details are only needed for extraction; release them even on failure.
                foreach (ChipDetails det in dets)
                    det.Dispose();
            }
        }

        return faceChips;
    }

    /// <summary>
    /// Runs the network once over all chips and appends one descriptor (cast from float to double)
    /// per face to the matching image's list.
    /// </summary>
    private static void AddDescriptors(LossMetric net,
                                       List<Matrix<RgbPixel>> faceChips,
                                       IReadOnlyList<FullObjectDetection[]> facesPerImage,
                                       List<List<Matrix<double>>> faceDescriptors)
    {
        // Disposed after the casts below: MatrixCast produces new matrices, so the network
        // output is no longer needed once every descriptor has been converted.
        using OutputLabels<Matrix<float>> descriptors = net.Operator(faceChips, BatchSize);
        Matrix<float>[] list = descriptors.Select(matrix => matrix).ToArray();

        int index = 0;
        for (int i = 0; i < facesPerImage.Count; ++i)
        {
            for (int j = 0; j < facesPerImage[i].Length; ++j)
                faceDescriptors[i].Add(DlibDotNet.Dlib.MatrixCast<double>(list[index++]));
        }

        if (index != list.Length)
            throw new ApplicationException("The number of descriptors returned by the network does not match the number of faces.");
    }

    /// <summary>
    /// For each chip, runs the network over <paramref name="numberOfJitters"/> jittered copies and
    /// appends the mean output (as doubles) to the matching image's list.
    /// </summary>
    private static void AddJitteredDescriptors(LossMetric net,
                                               List<Matrix<RgbPixel>> faceChips,
                                               IReadOnlyList<FullObjectDetection[]> facesPerImage,
                                               List<List<Matrix<double>>> faceDescriptors,
                                               int numberOfJitters)
    {
        int index = 0;
        for (int i = 0; i < facesPerImage.Count; ++i)
        {
            for (int j = 0; j < facesPerImage[i].Length; ++j)
            {
                Matrix<RgbPixel>[] jittered = JitterImage(faceChips[index++], numberOfJitters).ToArray();
                try
                {
                    using OutputLabels<Matrix<float>> output = net.Operator(jittered, BatchSize);
                    using MatrixOp mat = DlibDotNet.Dlib.Mat(output);
                    faceDescriptors[i].Add(DlibDotNet.Dlib.Mean<double>(mat));
                }
                finally
                {
                    foreach (Matrix<RgbPixel> matrix in jittered)
                        matrix.Dispose();
                }
            }
        }

        if (index != faceChips.Count)
            throw new ApplicationException("The number of extracted face chips does not match the number of faces.");
    }

    /// <summary>Disposes every chip in every per-image chip array, then the arrays themselves.</summary>
    private static void DisposeFaceChips(IEnumerable<Array<Matrix<RgbPixel>>> faceChipsArray)
    {
        foreach (Array<Matrix<RgbPixel>> array in faceChipsArray)
        {
            foreach (Matrix<RgbPixel> faceChip in array)
                faceChip.Dispose();
            array.Dispose();
        }
    }

    /// <summary>Returns <paramref name="numberOfJitters"/> randomly jittered copies of <paramref name="img"/>.</summary>
    private static IEnumerable<Matrix<RgbPixel>> JitterImage(Matrix<RgbPixel> img, int numberOfJitters)
    {
        List<Matrix<RgbPixel>> crops = new();
        for (int i = 0; i < numberOfJitters; ++i)
            crops.Add(DlibDotNet.Dlib.JitterImage(img, _Rand));

        return crops;
    }

    #endregion

    #endregion

}