using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace View_by_Distance.ThumbHash.Models;

public static partial class ThumbHash
{

    // Size, in bytes, of the scratch buffer the array-returning encoder allocates for a hash.
    private const int _MaxHash = 25;
    // Smallest span length accepted for a hash: the five header bytes
    // (a 24-bit header followed by a 16-bit header).
    private const int _MinHash = 5;

    /// <summary>
    /// Unconditionally throws an <see cref="ArgumentOutOfRangeException"/> reporting that
    /// <paramref name="value"/> is below the allowed minimum. The caller performs the comparison.
    /// </summary>
    /// <param name="value">The offending value; recorded as the exception's actual value.</param>
    /// <param name="other">The minimum that <paramref name="value"/> failed to reach.</param>
    /// <param name="paramName">Filled in by the compiler with the source text of the <paramref name="value"/> argument.</param>
    /// <exception cref="ArgumentOutOfRangeException">Always thrown.</exception>
    [DoesNotReturn]
    static void ThrowIfLessThan<T>(T value, T other, [CallerArgumentExpression(nameof(value))] string? paramName = null) =>
        // Name the argument in the message (as the sibling helpers do); the exception
        // already appends the actual value on its own.
        throw new ArgumentOutOfRangeException(paramName, value, $"'{paramName}' must be greater than or equal to '{other}'.");

    /// <summary>
    /// Unconditionally throws an <see cref="ArgumentOutOfRangeException"/> reporting that
    /// <paramref name="value"/> exceeds the allowed maximum. The caller performs the comparison.
    /// </summary>
    /// <param name="value">The offending value; recorded as the exception's actual value.</param>
    /// <param name="other">The maximum that <paramref name="value"/> exceeded.</param>
    /// <param name="paramName">Filled in by the compiler with the source text of the <paramref name="value"/> argument.</param>
    /// <exception cref="ArgumentOutOfRangeException">Always thrown.</exception>
    [DoesNotReturn]
    static void ThrowIfGreaterThan<T>(T value, T other, [CallerArgumentExpression(nameof(value))] string? paramName = null)
    {
        string message = $"'{paramName}' must be less than or equal to '{other}'.";
        throw new ArgumentOutOfRangeException(paramName, value, message);
    }

    /// <summary>
    /// Unconditionally throws an <see cref="ArgumentOutOfRangeException"/> reporting that
    /// <paramref name="value"/> differs from the required value. The caller performs the comparison.
    /// </summary>
    /// <param name="value">The offending value; recorded as the exception's actual value.</param>
    /// <param name="other">The value that <paramref name="value"/> was required to equal.</param>
    /// <param name="paramName">Filled in by the compiler with the source text of the <paramref name="value"/> argument.</param>
    /// <param name="otherName">Filled in by the compiler with the source text of the <paramref name="other"/> argument.</param>
    /// <exception cref="ArgumentOutOfRangeException">Always thrown.</exception>
    [DoesNotReturn]
    static void ThrowNotEqual<T>(T value, T other, [CallerArgumentExpression(nameof(value))] string? paramName = null, [CallerArgumentExpression(nameof(other))] string? otherName = null)
    {
        string message = $"'{paramName}' must be equal to '{other}' ('{otherName}').";
        throw new ArgumentOutOfRangeException(paramName, value, message);
    }

    /// <summary>
    /// Encodes an RGBA image to a ThumbHash and returns it as a newly allocated array.
    /// </summary>
    /// <param name="width">The width of the input image. Must be ≤100px.</param>
    /// <param name="height">The height of the input image. Must be ≤100px.</param>
    /// <param name="rgba">The pixels in the input image, row-by-row. RGB should not be premultiplied by A. Must have `width*height*4` elements.</param>
    /// <returns>Byte array containing exactly the bytes of the ThumbHash.</returns>
    public static byte[] RgbaToThumbHash(int width, int height, ReadOnlySpan<byte> rgba)
    {
        // Encode into a stack scratch buffer, then copy out only the bytes that were written.
        Span<byte> buffer = stackalloc byte[_MaxHash];
        int length = RgbaToThumbHash(buffer, width, height, rgba);
        return buffer.Slice(0, length).ToArray();
    }

    /// <summary>
    /// Encodes an RGBA image to a ThumbHash, writing the bytes into a caller-supplied span.
    /// </summary>
    /// <param name="hash">Destination for the encoded bytes. It must hold the complete hash: five header bytes, one extra byte when the image is not fully opaque, and one byte per pair of 4-bit AC terms.</param>
    /// <param name="w">The width of the input image. Must be positive; intended to be ≤100px (larger images are slow to encode with no benefit, but the limit is not enforced here).</param>
    /// <param name="h">The height of the input image. Must be positive; intended to be ≤100px (not enforced here).</param>
    /// <param name="rgba_bytes">The pixels in the input image, row-by-row. RGB should not be premultiplied by A. Must have `w*h*4` elements.</param>
    /// <returns>Number of bytes written into hash span</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown if <paramref name="hash"/> is too small for the encoded hash, if <paramref name="w"/> or
    /// <paramref name="h"/> is not positive, or if <paramref name="rgba_bytes"/> does not have `w*h*4` elements.
    /// Nothing is written to <paramref name="hash"/> in any of these cases.
    /// </exception>
    public static int RgbaToThumbHash(Span<byte> hash, int w, int h, ReadOnlySpan<byte> rgba_bytes)
    {
        // Widest image for which the per-row cosine table is kept on the stack.
        const int MaxStackFloats = 256;

        if (hash.Length < _MinHash)
            ThrowIfLessThan(hash.Length, _MinHash);

        // A zero or negative dimension would turn the averages and DCT terms below into
        // divisions by zero, so reject it before doing any work.
        if (w < 1)
            ThrowIfLessThan(w, 1);
        if (h < 1)
            ThrowIfLessThan(h, 1);

        // Compare in 64-bit so an overflowing `w * h * 4` cannot alias a valid length.
        // Once this passes, `w * h` is known to fit in an int for the rest of the method.
        if (rgba_bytes.Length != (long)w * h * 4)
            ThrowNotEqual(rgba_bytes.Length, w * h * 4);

        // Determine the average color
        float avg_r = 0.0f;
        float avg_g = 0.0f;
        float avg_b = 0.0f;
        float avg_a = 0.0f;

        ReadOnlySpan<RGBA> rgba = MemoryMarshal.Cast<byte, RGBA>(rgba_bytes);
        foreach (ref readonly RGBA pixel in rgba)
        {
            float alpha = pixel.A / 255.0f;
            avg_b += alpha / 255.0f * pixel.B;
            avg_g += alpha / 255.0f * pixel.G;
            avg_r += alpha / 255.0f * pixel.R;
            avg_a += alpha;
        }

        if (avg_a > 0.0f)
        {
            avg_r /= avg_a;
            avg_g /= avg_a;
            avg_b /= avg_a;
        }

        // NOTE(review): MathF.Round rounds midpoints to even; other ThumbHash implementations
        // appear to round half up, so exact .5 cases may encode differently — confirm whether
        // cross-implementation byte equality matters before changing this.
        bool has_alpha = avg_a < (w * h);
        int l_limit = has_alpha ? 5 : 7; // Use fewer luminance bits if there's alpha
        int lx = Math.Max((int)MathF.Round(l_limit * w / MathF.Max(w, h)), 1);
        int ly = Math.Max((int)MathF.Round(l_limit * h / MathF.Max(w, h)), 1);

        // Counts the AC terms encode_channel emits for an nx-by-ny channel
        // (same traversal, with the DC term at cx == 0, cy == 0 skipped).
        static int CountAcTerms(int nx, int ny)
        {
            int count = 0;
            for (int cy = 0; cy < ny; cy++)
            {
                for (int cx = cy > 0 ? 0 : 1; cx * ny < nx * (ny - cy); cx++)
                {
                    count++;
                }
            }
            return count;
        }

        // Make sure the destination can hold the whole hash before touching it, so a short
        // buffer fails up front instead of part-way through writing.
        int ac_terms = CountAcTerms(Math.Max(lx, 3), Math.Max(ly, 3))
            + 2 * CountAcTerms(3, 3)
            + (has_alpha ? CountAcTerms(5, 5) : 0);
        int required = (has_alpha ? _MinHash + 1 : _MinHash) + (ac_terms + 1) / 2; // two 4-bit terms per byte
        if (hash.Length < required)
            ThrowIfLessThan(hash.Length, required);

        using SpanOwner<float> l_owner = new(w * h); // l: luminance
        using SpanOwner<float> p_owner = new(w * h); // p: yellow - blue
        using SpanOwner<float> q_owner = new(w * h); // q: red - green
        using SpanOwner<float> a_owner = new(w * h); // a: alpha

        Span<float> l = l_owner.Span;
        Span<float> p = p_owner.Span;
        Span<float> q = q_owner.Span;
        Span<float> a = a_owner.Span;

        // Convert the image from RGBA to LPQA (composite atop the average color)
        int j = 0;
        foreach (ref readonly RGBA pixel in rgba)
        {
            float alpha = pixel.A / 255.0f;
            float b = avg_b * (1.0f - alpha) + alpha / 255.0f * pixel.B;
            float g = avg_g * (1.0f - alpha) + alpha / 255.0f * pixel.G;
            float r = avg_r * (1.0f - alpha) + alpha / 255.0f * pixel.R;
            a[j] = alpha;
            q[j] = r - g;
            p[j] = (r + g) / 2.0f - b;
            l[j] = (r + g + b) / 3.0f;
            j += 1;
        }

        // Encode using the DCT into DC (constant) and normalized AC (varying) terms
        Channel encode_channel(ReadOnlySpan<float> channel, int nx, int ny)
        {
            float dc = 0.0f;
            SpanOwner<float> ac_owner = new(nx * ny);
            float scale = 0.0f;

            // Horizontal cosine factors for the current cx. The width is not capped by this
            // method, so only small tables go on the stack; wide images use the heap.
            Span<float> fx = w <= MaxStackFloats ? stackalloc float[w] : new float[w];
            Span<float> ac = ac_owner.Span;
            int n = 0;
            for (int cy = 0; cy < ny; cy++)
            {
                int cx = 0;
                while (cx * ny < nx * (ny - cy))
                {
                    float f = 0.0f;
                    for (int x = 0; x < w; x++)
                    {
                        fx[x] = MathF.Cos(MathF.PI / w * cx * (x + 0.5f));
                    }
                    for (int y = 0; y < h; y++)
                    {
                        float fy = MathF.Cos(MathF.PI / h * cy * (y + 0.5f));
                        for (int x = 0; x < w; x++)
                        {
                            f += channel[x + y * w] * fx[x] * fy;
                        }
                    }
                    f /= w * h;
                    if (cx > 0 || cy > 0)
                    {
                        ac[n++] = f;
                        scale = MathF.Max(MathF.Abs(f), scale);
                    }
                    else
                    {
                        dc = f;
                    }
                    cx += 1;
                }
            }
            ac_owner = ac_owner.WithLength(n);
            ac = ac_owner.Span;

            // Map each AC term from [-scale, scale] to [0, 1] so it can be quantized to 4 bits.
            if (scale > 0.0f)
            {
                foreach (ref float aci in ac)
                {
                    aci = 0.5f + 0.5f / scale * aci;
                }
            }

            return new Channel(dc, ac_owner, scale);
        };

        (float l_dc, SpanOwner<float> l_ac, float l_scale) = encode_channel(l, Math.Max(lx, 3), Math.Max(ly, 3));
        (float p_dc, SpanOwner<float> p_ac, float p_scale) = encode_channel(p, 3, 3);
        (float q_dc, SpanOwner<float> q_ac, float q_scale) = encode_channel(q, 3, 3);
        (float a_dc, SpanOwner<float> a_ac, float a_scale) = has_alpha ? encode_channel(a, 5, 5) : new Channel(1.0f, SpanOwner<float>.Empty, 1.0f);

        // Write the constants
        bool is_landscape = w > h;
        uint header24 = (uint)MathF.Round(63.0f * l_dc)
            | (((uint)MathF.Round(31.5f + 31.5f * p_dc)) << 6)
            | (((uint)MathF.Round(31.5f + 31.5f * q_dc)) << 12)
            | (((uint)MathF.Round(31.0f * l_scale)) << 18)
            | (has_alpha ? 1u << 23 : 0);
        int header16 = (ushort)(is_landscape ? ly : lx)
            | (((ushort)MathF.Round(63.0f * p_scale)) << 3)
            | (((ushort)MathF.Round(63.0f * q_scale)) << 9)
            | (is_landscape ? 1 << 15 : 0);

        int hi = 0;
        hash[hi++] = (byte)header24;
        hash[hi++] = (byte)(header24 >> 8);
        hash[hi++] = (byte)(header24 >> 16);
        hash[hi++] = (byte)header16;
        hash[hi++] = (byte)(header16 >> 8);
        if (has_alpha)
        {
            float fa_dc = MathF.Round(15.0f * a_dc);
            float fa_scale = MathF.Round(15.0f * a_scale);
            byte ia_dc = (byte)fa_dc;
            byte ia_scale = (byte)fa_scale;
            hash[hi++] = (byte)(ia_dc | (ia_scale << 4));
        }

        // Write the varying factors: two 4-bit terms per byte, low nibble first.
        // `is_odd` carries the nibble position across channels.
        static void WriteFactor(ReadOnlySpan<float> ac, ref bool is_odd, ref int hi, Span<byte> hash)
        {
            for (int i = 0; i < ac.Length; i++)
            {
                byte u = (byte)MathF.Round(15.0f * ac[i]);
                if (is_odd)
                {
                    hash[hi - 1] |= (byte)(u << 4);
                }
                else
                {
                    hash[hi++] = u;
                }
                is_odd = !is_odd;
            }
        }

        using (l_ac)
        using (p_ac)
        using (q_ac)
        using (a_ac)
        {
            bool is_odd = false;
            WriteFactor(l_ac.Span, ref is_odd, ref hi, hash);
            WriteFactor(p_ac.Span, ref is_odd, ref hi, hash);
            WriteFactor(q_ac.Span, ref is_odd, ref hi, hash);
            if (has_alpha)
            {
                WriteFactor(a_ac.Span, ref is_odd, ref hi, hash);
            }
        }

        return hi;
    }

    /// <summary>
    /// Decodes a ThumbHash to an RGBA image and returns the pixels as a newly allocated array.
    /// </summary>
    /// <param name="hash">The ThumbHash bytes to decode.</param>
    /// <param name="w">The width, in pixels, of the image to render.</param>
    /// <param name="h">The height, in pixels, of the image to render.</param>
    /// <returns>The `w*h*4` unpremultiplied RGBA8 bytes of the rendered ThumbHash, row-by-row.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the input is too short.</exception>
    public static byte[] ThumbHashToRgba(ReadOnlySpan<byte> hash, int w, int h)
    {
        // Decode straight into the array that is returned; an intermediate buffer
        // would only add a second allocation and a copy.
        byte[] rgba = new byte[w * h * 4];
        ThumbHashToRgba(hash, w, h, rgba);
        return rgba;
    }

    /// <summary>
    /// Decodes a ThumbHash to an RGBA image, writing the pixels into a caller-supplied span.
    /// </summary>
    /// <param name="hash">The ThumbHash bytes to decode.</param>
    /// <param name="w">The width, in pixels, of the image to render.</param>
    /// <param name="h">The height, in pixels, of the image to render.</param>
    /// <param name="rgba">Destination for the unpremultiplied RGBA8 pixels, row-by-row. Must have at least `w * h * 4` bytes; only the first `w * h * 4` are written.</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the input is too short.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the RGBA span length is less than `w * h * 4` bytes.</exception>
    public static void ThumbHashToRgba(ReadOnlySpan<byte> hash, int w, int h, Span<byte> rgba)
    {
        // The five header bytes are read unconditionally below.
        if (hash.Length < _MinHash)
            ThrowIfLessThan(hash.Length, _MinHash);

        if (rgba.Length < w * h * 4)
            ThrowIfLessThan(rgba.Length, w * h * 4);

        // Read the constants
        uint header24 = hash[0]
            | (((uint)hash[1]) << 8)
            | (((uint)hash[2]) << 16);
        int header16 = hash[3] | (hash[4] << 8);
        float l_dc = (header24 & 63) / 63.0f;
        float p_dc = ((header24 >> 6) & 63) / 31.5f - 1.0f;
        float q_dc = ((header24 >> 12) & 63) / 31.5f - 1.0f;
        float l_scale = ((header24 >> 18) & 31) / 31.0f;
        bool has_alpha = (header24 >> 23) != 0;
        float p_scale = ((header16 >> 3) & 63) / 63.0f;
        float q_scale = ((header16 >> 9) & 63) / 63.0f;
        bool is_landscape = (header16 >> 15) != 0;
        int l_max = has_alpha ? 5 : 7;
        int lx = Math.Max(3, is_landscape ? l_max : header16 & 7);
        int ly = Math.Max(3, is_landscape ? header16 & 7 : l_max);

        // Counts the AC terms decode_channel reads for an nx-by-ny channel (same traversal).
        static int CountAcTerms(int nx, int ny)
        {
            int count = 0;
            for (int cy = 0; cy < ny; cy++)
            {
                for (int cx = cy > 0 ? 0 : 1; cx * ny < nx * (ny - cy); cx++)
                {
                    count++;
                }
            }
            return count;
        }

        // The header determines how many 4-bit AC terms follow. Verify they are all present
        // (along with the alpha byte, when flagged) before reading them or allocating buffers.
        int ac_start = has_alpha ? 6 : 5;
        int ac_terms = CountAcTerms(lx, ly)
            + 2 * CountAcTerms(3, 3)
            + (has_alpha ? CountAcTerms(5, 5) : 0);
        int required = ac_start + (ac_terms + 1) / 2; // two terms per byte
        if (hash.Length < required)
            ThrowIfLessThan(hash.Length, required);

        (float a_dc, float a_scale) = has_alpha ? ((hash[5] & 15) / 15.0f, (hash[5] >> 4) / 15.0f) : (1.0f, 1.0f);

        // No pixels to render for an empty image.
        if (w <= 0 || h <= 0)
            return;

        // Read the varying factors (boost saturation by 1.25x to compensate for quantization)
        static SpanOwner<float> decode_channel(ReadOnlySpan<byte> hash, int start, ref int index, int nx, int ny, float scale)
        {
            SpanOwner<float> ac_owner = new(nx * ny);
            Span<float> ac = ac_owner.Span;
            int n = 0;
            for (int cy = 0; cy < ny; cy++)
            {
                for (int cx = cy > 0 ? 0 : 1; cx * ny < nx * (ny - cy); cx++, n++, index++)
                {
                    // `index` counts nibbles: even ones are the low 4 bits of a byte, odd ones the high 4.
                    int data = hash[start + (index >> 1)] >> ((index & 1) << 2);
                    ac[n] = ((data & 15) / 7.5f - 1.0f) * scale;
                }
            }

            return ac_owner.WithLength(n);
        };

        int ac_index = 0;

        using SpanOwner<float> l_ac_owner = decode_channel(hash, ac_start, ref ac_index, lx, ly, l_scale);
        using SpanOwner<float> p_ac_owner = decode_channel(hash, ac_start, ref ac_index, 3, 3, p_scale * 1.25f);
        using SpanOwner<float> q_ac_owner = decode_channel(hash, ac_start, ref ac_index, 3, 3, q_scale * 1.25f);
        using SpanOwner<float> a_ac_owner = has_alpha ? decode_channel(hash, ac_start, ref ac_index, 5, 5, a_scale) : SpanOwner<float>.Empty;
        Span<float> l_ac = l_ac_owner.Span;
        Span<float> p_ac = p_ac_owner.Span;
        Span<float> q_ac = q_ac_owner.Span;
        Span<float> a_ac = a_ac_owner.Span;

        // lx and ly are at most 7 (a 3-bit header field or l_max), so seven entries suffice.
        Span<float> fx = stackalloc float[7];
        Span<float> fy = stackalloc float[7];
        int fx_terms = Math.Max(lx, has_alpha ? 5 : 3);
        int fy_terms = Math.Max(ly, has_alpha ? 5 : 3);

        // Decode using the DCT into RGB
        ref RGBA pixel = ref MemoryMarshal.AsRef<RGBA>(rgba);
        for (int y = 0; y < h; y++)
        {
            // Precompute the vertical coefficients once per row; they do not depend on x.
            for (int cy = 0; cy < fy_terms; cy++)
            {
                fy[cy] = MathF.Cos(MathF.PI / h * (y + 0.5f) * cy);
            }

            for (int x = 0; x < w; x++, pixel = ref Unsafe.AddByteOffset(ref pixel, 4))
            {
                float l = l_dc;
                float p = p_dc;
                float q = q_dc;
                float a = a_dc;

                // Precompute the horizontal coefficients
                for (int cx = 0; cx < fx_terms; cx++)
                {
                    fx[cx] = MathF.Cos(MathF.PI / w * (x + 0.5f) * cx);
                }

                // Decode L
                for (int cy = 0, j = 0; cy < ly; cy++)
                {
                    int cx = cy > 0 ? 0 : 1;
                    float fy2 = fy[cy] * 2.0f;
                    while (cx * ly < lx * (ly - cy))
                    {
                        l += l_ac[j] * fx[cx] * fy2;
                        j += 1;
                        cx += 1;
                    }
                }

                // Decode P and Q
                for (int cy = 0, j = 0; cy < 3; cy++)
                {
                    int cx = cy > 0 ? 0 : 1;
                    float fy2 = fy[cy] * 2.0f;
                    while (cx < 3 - cy)
                    {
                        float f = fx[cx] * fy2;
                        p += p_ac[j] * f;
                        q += q_ac[j] * f;
                        j += 1;
                        cx += 1;
                    }
                }

                // Decode A
                if (has_alpha)
                {
                    for (int cy = 0, j = 0; cy < 5; cy++)
                    {
                        int cx = cy > 0 ? 0 : 1;
                        float fy2 = fy[cy] * 2.0f;
                        while (cx < 5 - cy)
                        {
                            a += a_ac[j] * fx[cx] * fy2;
                            j += 1;
                            cx += 1;
                        }
                    }
                }

                // Convert to RGB
                float b = l - 2.0f / 3.0f * p;
                float r = (3.0f * l - b + q) / 2.0f;
                float g = r - q;

                pixel = new(
                    r: (byte)(Math.Clamp(r, 0.0f, 1.0f) * 255.0f),
                    g: (byte)(Math.Clamp(g, 0.0f, 1.0f) * 255.0f),
                    b: (byte)(Math.Clamp(b, 0.0f, 1.0f) * 255.0f),
                    a: (byte)(Math.Clamp(a, 0.0f, 1.0f) * 255.0f));
            }
        }
    }

    /// <summary>
    /// Extracts the average color from a ThumbHash.
    /// </summary>
    /// <param name="hash">The ThumbHash bytes. Only the header (and the alpha byte, when the alpha flag is set) is read.</param>
    /// <returns>Unpremultiplied RGBA values where each value ranges from 0 to 1. </returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the input is too short.</exception>
    public static (float r, float g, float b, float a) ThumbHashToAverageRgba(ReadOnlySpan<byte> hash)
    {
        if (hash.Length < _MinHash)
            ThrowIfLessThan(hash.Length, _MinHash);

        uint header = hash[0] | ((uint)hash[1] << 8) | ((uint)hash[2] << 16);
        float l = (header & 63) / 63.0f;
        float p = ((header >> 6) & 63) / 31.5f - 1.0f;
        float q = ((header >> 12) & 63) / 31.5f - 1.0f;
        bool has_alpha = (header >> 23) != 0;

        // When the alpha flag is set the average alpha lives in a sixth byte; make sure it exists.
        if (has_alpha && hash.Length < _MinHash + 1)
            ThrowIfLessThan(hash.Length, _MinHash + 1);

        float a = has_alpha ? (hash[5] & 15) / 15.0f : 1.0f;

        // Convert the LPQ average back to RGB.
        float b = l - 2.0f / 3.0f * p;
        float r = (3.0f * l - b + q) / 2.0f;
        float g = r - q;

        return (r: Math.Clamp(r, 0.0f, 1.0f),
                g: Math.Clamp(g, 0.0f, 1.0f),
                b: Math.Clamp(b, 0.0f, 1.0f),
                a);
    }

}