using DlibDotNet;
using DlibDotNet.Dnn;
using View_by_Distance.Shared.Models.Stateless;

namespace View_by_Distance.FaceRecognitionDotNet.Dlib.Python;

/// <summary>
/// Runs an MMOD CNN face detector (<see cref="LossMmod"/>) over one image or a batch of
/// images, optionally pyramid-upsampling first, and maps detections back to the
/// coordinates of the original image(s).
/// </summary>
internal sealed class CnnFaceDetectionModelV1
{

    #region Methods

    /// <summary>
    /// Detects faces in a single image.
    /// </summary>
    /// <param name="net">The MMOD network to run.</param>
    /// <param name="image">Source image; its mode must be <c>Mode.Greyscale</c> or <c>Mode.Rgb</c>.</param>
    /// <param name="upsampleNumTimes">
    /// Number of pyramid-up steps applied before detection (0 = none). Upsampling finds
    /// smaller faces at the cost of more memory and time.
    /// </param>
    /// <returns>The detections, with rectangles expressed in the original image's coordinates.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="net"/> or <paramref name="image"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="upsampleNumTimes"/> is negative.</exception>
    /// <exception cref="NotSupportedException">The image mode is neither greyscale nor RGB.</exception>
    public static IEnumerable<MModRect> Detect(LossMmod net, Image image, int upsampleNumTimes)
    {
        if (net is null)
            throw new ArgumentNullException(nameof(net));
        if (image is null)
            throw new ArgumentNullException(nameof(image));
        ThrowIfNegativeUpsample(upsampleNumTimes);

        using PyramidDown pyr = new(2);

        // Copy the data into dlib based objects
        using Matrix<RgbPixel> matrix = new();
        AssignToMatrix(image, matrix);

        // Upsampling the image will allow us to detect smaller faces but will cause the
        // program to use more RAM and run longer.
        for (int level = 0; level < upsampleNumTimes; level++)
            DlibDotNet.Dlib.PyramidUp<PyramidDown>(matrix, 2);

        // The OutputLabels object wraps native memory, so dispose it deterministically.
        // ScaleToOriginal materializes its result before this scope ends.
        // NOTE(review): assumes the MModRect instances it yields are independent of the
        // OutputLabels lifetime — confirm against DlibDotNet's LossMmod output implementation.
        using OutputLabels<IEnumerable<MModRect>> dets = net.Operator(matrix);

        // Scale the detection locations back to the original image size
        // if the image was upscaled.
        return ScaleToOriginal(pyr, dets.First(), upsampleNumTimes);
    }

    /// <summary>
    /// Detects faces in a batch of images that must all share the same dimensions.
    /// </summary>
    /// <param name="net">The MMOD network to run.</param>
    /// <param name="images">Source images; each mode must be <c>Mode.Greyscale</c> or <c>Mode.Rgb</c>.</param>
    /// <param name="upsampleNumTimes">Number of pyramid-up steps applied to every image before detection (0 = none).</param>
    /// <param name="batchSize">Maximum number of images the network processes per batch; must be at least 1.</param>
    /// <returns>One list of detections per input image, in input order, in original-image coordinates.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="net"/> or <paramref name="images"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="upsampleNumTimes"/> is negative or <paramref name="batchSize"/> is less than 1.
    /// </exception>
    /// <exception cref="ArgumentException">An image is null, or the images differ in size after upsampling.</exception>
    /// <exception cref="NotSupportedException">An image mode is neither greyscale nor RGB.</exception>
    public static IEnumerable<IEnumerable<MModRect>> DetectMulti(LossMmod net, IEnumerable<Image> images, int upsampleNumTimes, int batchSize = 128)
    {
        if (net is null)
            throw new ArgumentNullException(nameof(net));
        if (images is null)
            throw new ArgumentNullException(nameof(images));
        ThrowIfNegativeUpsample(upsampleNumTimes);
        // A zero batch size would be handed to the native batched operator as a zero
        // batch step, which never advances through the images.
        if (batchSize < 1)
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be at least 1.");

        List<Matrix<RgbPixel>> destImages = new();
        List<IEnumerable<MModRect>> allRects = new();

        try
        {
            using PyramidDown pyr = new(2);

            // Copy the data into dlib based objects
            foreach (Image image in images)
            {
                if (image is null)
                    throw new ArgumentException("Images in list must not be null.", nameof(images));

                // Register the matrix before filling it so the finally block disposes it
                // even when conversion or upsampling throws.
                Matrix<RgbPixel> matrix = new();
                destImages.Add(matrix);
                AssignToMatrix(image, matrix);

                for (int level = 0; level < upsampleNumTimes; level++)
                    DlibDotNet.Dlib.PyramidUp(matrix);
            }

            // Nothing to run the network on.
            if (destImages.Count == 0)
                return allRects;

            // The network processes the batch as a single tensor, so every image must match.
            for (int i = 1; i < destImages.Count; i++)
            {
                if (destImages[i - 1].Columns != destImages[i].Columns || destImages[i - 1].Rows != destImages[i].Rows)
                    throw new ArgumentException("Images in list must all have the same dimensions.");
            }

            // Dispose the native output once every detection has been rescaled into allRects.
            using OutputLabels<IEnumerable<MModRect>> dets = net.Operator(destImages, (ulong)batchSize);
            foreach (IEnumerable<MModRect> det in dets)
                allRects.Add(ScaleToOriginal(pyr, det, upsampleNumTimes));
        }
        finally
        {
            foreach (Matrix<RgbPixel> matrix in destImages)
                matrix.Dispose();
        }

        return allRects;
    }

    /// <summary>
    /// Copies a greyscale or RGB <paramref name="image"/> into <paramref name="matrix"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The image mode is neither greyscale nor RGB.</exception>
    private static void AssignToMatrix(Image image, Matrix<RgbPixel> matrix)
    {
        switch (image.Mode)
        {
            case Mode.Greyscale:
            case Mode.Rgb:
                DlibDotNet.Dlib.AssignImage(image.Matrix, matrix);
                break;
            default:
                throw new NotSupportedException("Unsupported image type, must be 8bit gray or RGB image.");
        }
    }

    /// <summary>
    /// Maps each detection's rectangle from the upsampled image back to the original image,
    /// updating each <see cref="MModRect"/> in place, and returns them in a new list.
    /// </summary>
    private static List<MModRect> ScaleToOriginal(PyramidDown pyr, IEnumerable<MModRect> dets, int upsampleNumTimes)
    {
        List<MModRect> rects = new();
        foreach (MModRect d in dets)
        {
            DRectangle drect = pyr.RectDown(new DRectangle(d.Rect), (uint)upsampleNumTimes);
            // NOTE(review): these casts truncate toward zero; dlib's own rectangle conversion
            // may round instead. Left unchanged to keep existing face locations stable.
            d.Rect = new Rectangle((int)drect.Left, (int)drect.Top, (int)drect.Right, (int)drect.Bottom);
            rects.Add(d);
        }

        return rects;
    }

    /// <summary>
    /// Rejects a negative upsample count, which would otherwise wrap to a huge
    /// <see cref="uint"/> level count when passed to <c>RectDown</c>.
    /// </summary>
    private static void ThrowIfNegativeUpsample(int upsampleNumTimes)
    {
        if (upsampleNumTimes < 0)
            throw new ArgumentOutOfRangeException(nameof(upsampleNumTimes), upsampleNumTimes, "Upsample count must not be negative.");
    }

    #endregion

}