#pragma once
/*
 *  Copyright (C) 2024  Brett Terpstra
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef BLT_GRAPHICS_RAYCAST_H
#define BLT_GRAPHICS_RAYCAST_H

#include <blt/math/vectors.h>
#include <blt/math/matrix.h>
#include <type_traits>

namespace blt::gfx
{
    /// True when every type in `Args` is arithmetic and none of them is exactly `float`.
    ///
    /// Gates the integer/double convenience overloads of calculateRay2D / calculateRay3D below.
    /// Those templates are enabled only when *no* argument is already `float`. If any argument is a
    /// float, the template is rejected and the call goes to the non-template float overloads through
    /// ordinary implicit conversion.
    ///
    /// Notes:
    ///  - The check uses std::is_same, so `const float` is not recognised as float. This is harmless
    ///    for the overloads here because by-value template deduction strips top-level cv-qualifiers.
    ///  - An empty pack yields true, since both folds are over `&&`.
    ///  - Declared `constexpr` (rather than `const`) so compile-time initialisation is guaranteed
    ///    and it can be used in enable_if / static_assert unambiguously.
    template<typename... Args>
    inline constexpr bool not_float_v = (std::is_arithmetic_v<Args> && ...) && (!std::is_same_v<float, Args> && ...);
    
    // NOTE(review): these four functions are only declared here; their definitions are not part of
    // this header. The comments below describe what the signatures show; anything about their
    // runtime semantics is a hedged assumption to be confirmed against the definitions.
    
    /// Ray cast from an explicit cursor position (3D variant).
    /// @param mx, my         cursor position, presumably in window pixels (origin / y direction: TODO confirm)
    /// @param width, height  viewport extent, presumably in the same units as mx / my
    /// @param view, proj     view and projection matrices (judging by their names)
    /// @return a 3D vector, presumably the world-space ray direction; confirm in the definition
    /// Templated overloads for non-float arithmetic arguments are provided further down.
    blt::vec3 calculateRay3D(float mx, float my, float width, float height, const blt::mat4x4& view, const blt::mat4x4& proj);
    
    /// 3D variant without an explicit cursor position. Presumably the definition obtains the current
    /// mouse position itself; verify before relying on it.
    blt::vec3 calculateRay3D(float width, float height, const blt::mat4x4& view, const blt::mat4x4& proj);
    
    /// 2D variant. It also takes `scale`, whose meaning (zoom? world units per pixel?) cannot be
    /// determined from this header; TODO confirm.
    /// NOTE: `scale` is taken by value. The top-level `const` is not part of the signature. The
    /// templated overloads further down take it by const& and forward it here, so it is copied once.
    blt::vec3 calculateRay2D(float mx, float my, float width, float height, const blt::vec3 scale, const blt::mat4x4& view, const blt::mat4x4& proj);
    
    /// 2D variant without an explicit cursor position (see the matching 3D overload above).
    blt::vec3 calculateRay2D(float width, float height, const blt::vec3 scale, const blt::mat4x4& view, const blt::mat4x4& proj);
    
    /// Convenience overload of calculateRay2D(width, height, scale, view, proj) for non-float
    /// arithmetic dimensions (int, double, ...).
    /// It is enabled only when neither dimension is `float` (see not_float_v). Mixed float / non-float
    /// calls are resolved by the float overload through implicit conversion instead.
    template<typename WidthT, typename HeightT, std::enable_if_t<not_float_v<WidthT, HeightT>, bool> = true>
    inline blt::vec3 calculateRay2D(WidthT width, HeightT height, const blt::vec3& scale, const blt::mat4x4& view, const blt::mat4x4& proj)
    {
        const float widthF = static_cast<float>(width);
        const float heightF = static_cast<float>(height);
        // Both dimensions are float now, so not_float_v rejects this template and the
        // non-template float overload is chosen, which means no self-recursion.
        return calculateRay2D(widthF, heightF, scale, view, proj);
    }
    
    /// Convenience overload of calculateRay2D(mx, my, width, height, scale, view, proj) for non-float
    /// arithmetic inputs.
    /// It is enabled only when none of the four scalars is `float` (see not_float_v). It converts
    /// each scalar to float and forwards everything else unchanged.
    template<typename XT, typename YT, typename WidthT, typename HeightT,
             std::enable_if_t<not_float_v<XT, YT, WidthT, HeightT>, bool> = true>
    inline blt::vec3 calculateRay2D(XT mx, YT my, WidthT width, HeightT height, const blt::vec3& scale, const blt::mat4x4& view, const blt::mat4x4& proj)
    {
        const float xF = static_cast<float>(mx);
        const float yF = static_cast<float>(my);
        const float widthF = static_cast<float>(width);
        const float heightF = static_cast<float>(height);
        // All-float arguments fail not_float_v, so this resolves to the non-template overload.
        return calculateRay2D(xF, yF, widthF, heightF, scale, view, proj);
    }
    
    /// Convenience overload of calculateRay3D(width, height, view, proj) for non-float arithmetic
    /// dimensions.
    /// It is enabled only when neither dimension is `float` (see not_float_v). Mixed float / non-float
    /// calls are resolved by the float overload through implicit conversion instead.
    template<typename WidthT, typename HeightT, std::enable_if_t<not_float_v<WidthT, HeightT>, bool> = true>
    inline blt::vec3 calculateRay3D(WidthT width, HeightT height, const blt::mat4x4& view, const blt::mat4x4& proj)
    {
        const float widthF = static_cast<float>(width);
        const float heightF = static_cast<float>(height);
        // Float arguments disable this template, so the call below hits the non-template overload.
        return calculateRay3D(widthF, heightF, view, proj);
    }
    
    /// Convenience overload of calculateRay3D(mx, my, width, height, view, proj) for non-float
    /// arithmetic inputs.
    /// It is enabled only when none of the four scalars is `float` (see not_float_v). It converts
    /// each scalar to float and forwards the matrices unchanged.
    template<typename XT, typename YT, typename WidthT, typename HeightT,
             std::enable_if_t<not_float_v<XT, YT, WidthT, HeightT>, bool> = true>
    inline blt::vec3 calculateRay3D(XT mx, YT my, WidthT width, HeightT height, const blt::mat4x4& view, const blt::mat4x4& proj)
    {
        const float xF = static_cast<float>(mx);
        const float yF = static_cast<float>(my);
        const float widthF = static_cast<float>(width);
        const float heightF = static_cast<float>(height);
        // All-float arguments fail not_float_v, so this resolves to the non-template overload.
        return calculateRay3D(xF, yF, widthF, heightF, view, proj);
    }
    
    namespace detail
    {
        // Defined out of line; the definitions are not in this header. Judging by the names, these
        // step a clip-space vector back through the projection and view transforms. TODO confirm
        // against the definitions.
        blt::vec3 toWorldCoords(const blt::vec4& eyeCoords, const blt::mat4x4& view);
        
        blt::vec4 toEyeCoords(const blt::vec4& clipCoords, const blt::mat4x4& proj);
        
        /// Linearly remaps a point from [0, width] x [0, height] onto [-1, 1] x [-1, 1].
        /// The y axis is NOT flipped: my == 0 maps to -1 and my == height maps to +1.
        /// There is no guard for width or height == 0. Under IEEE-754 the division then yields inf/NaN.
        inline blt::vec2 getNormalisedDeviceCoordinates(float mx, float my, float width, float height)
        {
            // Same expression per axis: scale into [0, 2], then shift down by one.
            const auto remap = [](float pos, float extent) { return (2.0f * pos) / extent - 1.0f; };
            return blt::vec2(remap(mx, width), remap(my, height));
        }
    }
    
}

#endif //BLT_GRAPHICS_RAYCAST_H