local vector = require("geometry.vector")

--- Quaternion value type: x, y, z form the vector part and w the scalar part.
--- Instances are plain tables whose metatable is `Quaternion`; they are built
--- by calling the class table, e.g. `Quaternion(x, y, z, w)`.
---@class Quaternion : Class
---@field x number
---@field y number
---@field z number
---@field w number
local Quaternion = {}

-- Field lookups that miss on an instance (i.e. method names) fall through to
-- the class table itself.
Quaternion.__index = Quaternion

--- Component-wise addition.
--- When either operand is a plain number, that number is added to every
--- component (x, y, z and w alike). Otherwise both operands are quaternions
--- and are summed component by component.
---@param q Quaternion|number
---@return Quaternion
function Quaternion:__add(q)
    -- For `n + quat` Lua calls this metamethod with the number as `self`.
    -- Addition commutes, so swap the operands instead of indexing a number.
    if type(self) == "number" then
        self, q = q, self
    end
    if type(q) == "number" then
        return Quaternion(self.x + q, self.y + q, self.z + q, self.w + q)
    end
    return Quaternion(self.x + q.x, self.y + q.y, self.z + q.z, self.w + q.w)
end

--- Component-wise subtraction.
--- A plain number on either side is applied to every component (x, y, z and
--- w alike). Otherwise both operands are quaternions and are subtracted
--- component by component.
---@param q Quaternion|number
---@return Quaternion
function Quaternion:__sub(q)
    -- For `n - quat` Lua calls this metamethod with the number as `self`.
    -- Subtraction does not commute, so subtract each component from `self`.
    if type(self) == "number" then
        return Quaternion(self - q.x, self - q.y, self - q.z, self - q.w)
    end
    if type(q) == "number" then
        return Quaternion(self.x - q, self.y - q, self.z - q, self.w - q)
    end
    return Quaternion(self.x - q.x, self.y - q.y, self.z - q.z, self.w - q.w)
end

--- Multiplication.
--- * quaternion * number or number * quaternion: scales every component.
--- * quaternion * quaternion: Hamilton product. Used with `rotate`,
---   `(a * b):rotate(v)` applies b first, then a.
---@param q Quaternion|number
---@return Quaternion
function Quaternion:__mul(q)
    -- For `n * quat` Lua calls this metamethod with the number as `self`.
    -- Scalar scaling commutes, so swap the operands instead of indexing a number.
    if type(self) == "number" then
        self, q = q, self
    end
    if type(q) == "number" then
        return Quaternion(self.x * q, self.y * q, self.z * q, self.w * q)
    end
    return Quaternion(
        self.w * q.x + self.x * q.w + self.y * q.z - self.z * q.y,
        self.w * q.y + self.y * q.w + self.z * q.x - self.x * q.z,
        self.w * q.z + self.z * q.w + self.x * q.y - self.y * q.x,
        self.w * q.w - self.x * q.x - self.y * q.y - self.z * q.z
    )
end

--- Divides every component by the scalar `q`.
--- Only a numeric divisor is supported.
---@param q number
---@return Quaternion
function Quaternion:__div(q)
    local x, y, z, w = self.x / q, self.y / q, self.z / q, self.w / q
    return Quaternion(x, y, z, w)
end

--- Unary minus: negates all four components.
--- For rotations, `-q` encodes the same rotation as `q`.
---@return Quaternion
function Quaternion:__unm()
    local x, y, z, w = self.x, self.y, self.z, self.w
    return Quaternion(-x, -y, -z, -w)
end

--- Euclidean norm, exposed through the `#` operator.
--- NOTE: `__len` on tables is only honoured from Lua 5.2 on. Under Lua 5.1 /
--- LuaJIT (default build), `#q` ignores this metamethod, so prefer `q:length()`.
---@return number
function Quaternion:__len()
    local x, y, z, w = self.x, self.y, self.z, self.w
    return math.sqrt(x * x + y * y + z * z + w * w)
end

--- Exact component-wise equality (no floating-point tolerance).
---@param q Quaternion
---@return boolean
function Quaternion:__eq(q)
    if self.x ~= q.x or self.y ~= q.y then
        return false
    end
    return self.z == q.z and self.w == q.w
end

--- Formats as "(x, y, z, w)" with each component printed via `%f`.
---@return string
function Quaternion:__tostring()
    local fmt = "(%f, %f, %f, %f)"
    return fmt:format(self.x, self.y, self.z, self.w)
end

--- Euclidean norm sqrt(x^2 + y^2 + z^2 + w^2). Works on every Lua version,
--- unlike `#q`.
---@return number
function Quaternion:length()
    local x, y, z, w = self.x, self.y, self.z, self.w
    local sum_of_squares = x * x + y * y + z * z + w * w
    return math.sqrt(sum_of_squares)
end

--- Squared norm x^2 + y^2 + z^2 + w^2 (avoids the square root).
---@return number
function Quaternion:length_squared()
    local x, y, z, w = self.x, self.y, self.z, self.w
    return x * x + y * y + z * z + w * w
end

--- Returns a unit-length copy of this quaternion; `self` is not modified.
--- The norm comes from `self:length()` rather than `#self`. Lua 5.1 / LuaJIT
--- do not call the `__len` metamethod for tables, so there `#self` is the raw
--- table length (0) and every component became inf/NaN.
--- A zero quaternion has no direction; the identity is returned for it
--- instead of NaN components.
---@return Quaternion
function Quaternion:normalize()
    local len = self:length()
    if len == 0 then
        return Quaternion(0, 0, 0, 1)
    end
    return self / len
end

--- Conjugate: negates the vector part (x, y, z) and keeps the scalar part w.
---@return Quaternion
function Quaternion:conjugate()
    local w = self.w
    return Quaternion(-self.x, -self.y, -self.z, w)
end

--- Multiplicative inverse: conjugate divided by the squared norm.
--- For a unit quaternion this equals the conjugate.
---@return Quaternion
function Quaternion:inverse()
    local norm_sq = self:length_squared()
    return self:conjugate() / norm_sq
end

--- Four-component dot product with `q`.
---@param q Quaternion
---@return number
function Quaternion:dot(q)
    local partial = self.x * q.x + self.y * q.y
    return partial + self.z * q.z + self.w * q.w
end

--- Component-wise linear blend: self at t = 0, q at t = 1.
--- The result is not renormalised.
---@param q Quaternion
---@param t number
---@return Quaternion
function Quaternion:lerp(q, t)
    local from_part = self * (1 - t)
    local to_part = q * t
    return from_part + to_part
end

--- Spherical linear interpolation from `self` (t = 0) to `q` (t = 1).
--- Takes the shorter arc by negating `q` when the two lie in opposite
--- hemispheres. Inputs are expected to be unit quaternions.
---@param q Quaternion
---@param t number
---@return Quaternion
function Quaternion:slerp(q, t)
    local cos_theta = self:dot(q)
    if cos_theta < 0 then
        -- q and -q encode the same rotation; flip to interpolate the short way.
        q = -q
        cos_theta = -cos_theta
    end
    if cos_theta > 0.9995 then
        -- Nearly parallel: sin(theta) ~ 0 would make the weights below unstable.
        -- Fall back to a linear blend, renormalised so the result stays unit
        -- length like the spherical branch does for unit inputs.
        local blended = self:lerp(q, t)
        return blended / blended:length()
    end
    local theta = math.acos(cos_theta)
    return (self * math.sin((1 - t) * theta) + q * math.sin(t * theta)) / math.sin(theta)
end

--- Normalised linear interpolation: `lerp`, then `normalize`.
--- Unlike `slerp`, it does not flip `q` to take the shorter arc.
---@param q Quaternion
---@param t number
---@return Quaternion
function Quaternion:nlerp(q, t)
    local blended = self:lerp(q, t)
    return blended:normalize()
end

--- Rotates vector `v` (anything with x, y, z fields) by this quaternion.
--- Computes self * (v, 0) * conjugate(self), which is a pure rotation only
--- when `self` has unit length.
---@return table  a new `vector` with x, y, z
function Quaternion:rotate(v)
    -- Embed v as a pure quaternion (w = 0) before sandwiching it.
    local pure = Quaternion(v.x, v.y, v.z, 0)
    local conj = self:conjugate()
    local rotated = (self * pure) * conj
    return vector(rotated.x, rotated.y, rotated.z)
end

--- Packs the four components into a `vector` as (x, y, z, w).
--- NOTE(review): this relies on `geometry.vector` accepting a fourth
--- argument — confirm, since a 3D-only vector would silently drop w.
function Quaternion:to_vector()
    local x, y, z, w = self.x, self.y, self.z, self.w
    return vector(x, y, z, w)
end

--- Builds the quaternion rotating by `angle` radians about `axis`.
--- `axis` is used as given (it is not normalised here); pass a unit vector to
--- get a unit quaternion.
---@param axis table  anything with x, y, z fields
---@param angle number  radians
---@return Quaternion
function Quaternion.from_axis_angle(axis, angle)
    local half = angle / 2
    local sin_half, cos_half = math.sin(half), math.cos(half)
    return Quaternion(axis.x * sin_half, axis.y * sin_half, axis.z * sin_half, cos_half)
end

--- Decomposes this (expected unit) quaternion into a rotation axis and an
--- angle in radians.
--- When the rotation is (nearly) zero, so that the axis is undefined, returns
--- the x axis and an angle of 0.
---@return table axis  a `vector`
---@return number angle  radians
function Quaternion:to_axis_angle()
    -- Clamp w into acos's domain: rounding can push |w| of a unit quaternion
    -- just past 1, where math.acos yields NaN and poisons both results.
    local w = math.max(-1, math.min(1, self.w))
    local half_angle = math.acos(w)
    local s = math.sin(half_angle)
    if s < 0.0001 then
        return vector(1, 0, 0), 0
    end
    return vector(self.x / s, self.y / s, self.z / s), half_angle * 2
end

--- Builds a quaternion from Euler angles in radians, packed as
--- (x = yaw, y = pitch, z = roll), the packing used by the
--- `Quaternion{ yaw = ..., pitch = ..., roll = ... }` constructor form.
--- Yaw rotates about Z, pitch about Y and roll about X. The result equals
--- q_x(roll) * q_y(pitch) * q_z(yaw).
---@param euler table  anything with numeric x, y, z fields
---@return Quaternion
function Quaternion.from_euler(euler)
    local half_roll = euler.z / 2
    local half_pitch = euler.y / 2
    local half_yaw = euler.x / 2
    -- c*/s* = cosine/sine of the half angle (r = roll, p = pitch, y = yaw)
    local cr, sr = math.cos(half_roll), math.sin(half_roll)
    local cp, sp = math.cos(half_pitch), math.sin(half_pitch)
    local cy, sy = math.cos(half_yaw), math.sin(half_yaw)
    return Quaternion(
        sr * cp * cy + cr * sp * sy,
        cr * sp * cy - sr * cp * sy,
        cr * cp * sy + sr * sp * cy,
        cr * cp * cy - sr * sp * sy
    )
end

--- Converts to Euler angles in radians, returned as vector(yaw, pitch, roll).
--- This is the same packing and convention `from_euler` consumes (yaw about
--- Z, pitch about Y, roll about X, q = q_x(roll) * q_y(pitch) * q_z(yaw)), so
--- `Quaternion.from_euler(q:to_euler())` reproduces the rotation of q.
--- Pitch lies in [-pi/2, pi/2]. Near the poles (|pitch| ~ pi/2), yaw and roll
--- are not independent: roll is reported as 0 and the combined rotation is
--- folded into yaw.
--- Previously this used a heading/attitude/bank formula built for different
--- axes, so e.g. a pure Z rotation came back as pitch.
---@return table  a `vector` (yaw, pitch, roll)
function Quaternion:to_euler()
    -- math.atan2 is deprecated since Lua 5.3 and only present in builds with
    -- compatibility options; there the two-argument math.atan(y, x) replaces it.
    local atan2 = math.atan2 or math.atan
    local x, y, z, w = self.x, self.y, self.z, self.w
    local sqx, sqy, sqz, sqw = x * x, y * y, z * z, w * w
    -- Scaling by the squared norm keeps this valid for non-unit quaternions.
    local unit = sqx + sqy + sqz + sqw
    -- 2 * test / unit == sin(pitch) (element [0][2] of the rotation matrix).
    local test = x * z + y * w
    if test > 0.499 * unit then
        -- pitch ~ +pi/2: only yaw + roll is observable; report roll = 0.
        return vector(2 * atan2(x, w), math.pi / 2, 0)
    end
    if test < -0.499 * unit then
        -- pitch ~ -pi/2: only roll - yaw is observable; report roll = 0.
        return vector(-2 * atan2(x, w), -math.pi / 2, 0)
    end
    local yaw = atan2(2 * (z * w - x * y), sqw + sqx - sqy - sqz)
    local pitch = math.asin(2 * test / unit)
    local roll = atan2(2 * (x * w - y * z), sqw - sqx - sqy + sqz)
    return vector(yaw, pitch, roll)
end

setmetatable(Quaternion, {
    --- Constructor invoked as `Quaternion(...)`. Forms are checked in order:
    ---   Quaternion{ axis = v, angle = a }           -> from_axis_angle(v, a)
    ---   Quaternion{ yaw = a, pitch = b, roll = c }  -> from_euler(vector(a, b, c))
    ---   Quaternion{ x = a, y = b, z = c } (no w)    -> from_euler(vector(a, b, c))
    ---   Quaternion{ x = .., y = .., z = .., w = .. } -> component copy, missing
    ---                                                 x/y/z default to 0
    ---   Quaternion(x, y, z, w)                      -> components stored as given
    __call = function(_cls, x, y, z, w)
        if type(x) ~= "table" then
            return setmetatable({ x = x, y = y, z = z, w = w }, Quaternion)
        end
        local spec = x
        if spec.axis then
            return Quaternion.from_axis_angle(spec.axis, spec.angle)
        elseif spec.yaw then
            return Quaternion.from_euler(vector(spec.yaw, spec.pitch, spec.roll))
        elseif not spec.w then
            return Quaternion.from_euler(vector(spec.x, spec.y, spec.z))
        end
        return setmetatable({
            x = spec.x or 0,
            y = spec.y or 0,
            z = spec.z or 0,
            w = spec.w or 0,
        }, Quaternion)
    end,
})

-- The identity rotation: zero vector part, w = 1.
-- NOTE(review): this is one shared table; writing to its fields would change
-- `Quaternion.identity` for every user.
local identity = Quaternion(0, 0, 0, 1)
Quaternion.identity = identity

return Quaternion