local vector = require("util.geometry.vector")
local quaternion = require("util.geometry.quaternion")
local matrix = require("util.geometry.matrix")

-- Module table returned at the end of this file.
local collider = {}

-- Registry of collider type tables, keyed by type name. Entries are added by
-- collider.register and looked up by collider.new and collider.update.
collider.types = {}

-- The physics.force module, exposed through the collider table.
collider.force = require("physics.force")

--- Builds a collider instance from a description table.
--
-- Every field of `data` is copied onto the new instance first, then every
-- field of the registered type table, so type-level fields win on a name
-- clash. Afterwards the physical state is normalised and missing values get
-- defaults. If the resulting instance has an `init` field it is called as a
-- method once construction is complete.
--
-- @param data table describing the collider; `data.type` must name a type
--   previously stored in `collider.types`.
-- @return the new collider table.
-- Raises an error when `data.type` is not a registered collider type.
function collider.new(data)
    local prototype = collider.types[data.type]
    if not prototype then
        -- tostring() keeps the intended message readable when `type` is nil
        -- or otherwise not concatenable; level 2 blames the caller.
        error("Unknown collider type: " .. tostring(data.type), 2)
    end

    local self = {}

    for i, v in pairs(data) do
        self[i] = v
    end
    for i, v in pairs(prototype) do
        self[i] = v
    end

    self.position = vector.from(self.position or { x = 0, y = 0, z = 0 })
    self.center_of_mass = vector.from(self.center_of_mass or { x = 0, y = 0, z = 0 })
    self.velocity = vector.from(self.velocity or { x = 0, y = 0, z = 0 })
    self.orientation = quaternion(self.orientation or { x = 0, y = 0, z = 0, w = 1 })
    self.rotation = quaternion(self.rotation or { x = 0, y = 0, z = 0, w = 0 })
    self.scale = vector.from(self.scale or { x = 1, y = 1, z = 1 })
    -- Default mass is the cube of the scale vector's length.
    self.mass = self.mass or self.scale:length()^3
    self.inverse_mass = 1 / self.mass
    self.elasticity = self.elasticity or 0.5
    self.friction = self.friction or 0.5
    self.force = vector.from(self.force or { x = 0, y = 0, z = 0 })
    -- NOTE(review): unlike `rotation`, this default carries no `w` component;
    -- confirm how the quaternion constructor treats a missing `w`.
    self.torque = quaternion(self.torque or { x = 0, y = 0, z = 0 })
    self.inverse_inertia = matrix(3, 3)
    -- NOTE(review): unconditional, so a `fluid` supplied in `data` is
    -- overwritten; confirm whether that is intended.
    self.fluid = "air"

    if self.init then
        self:init()
    end

    return self
end

--- Registers a collider type table under its name.
--
-- The same table is also stored under "<type>_<type>", which is the key shape
-- collider.update builds (`c.type .. "_" .. other.type`) when both colliders
-- of a pair have this type.
--
-- @param type name of the collider type.
-- @param data table to associate with that name.
function collider.register(type, data)
    local registry = collider.types
    registry[type] = data
    local same_type_key = type .. "_" .. type
    registry[same_type_key] = registry[type]
end

-- Reports whether any element of the sequence `b` satisfies the predicate `c`.
-- `c` is called with one element at a time, in index order from 1 to #b, and
-- the scan stops at the first element for which it returns a truthy value.
-- Returns true on such a match, false otherwise (including for an empty `b`).
local function contains(b, c)
    local count = #b
    local index = 1
    while index <= count do
        if c(b[index]) then
            return true
        end
        index = index + 1
    end
    return false
end

--- Runs one update step over a list of colliders.
--
-- For every collider that is not flagged `dormant`, its `ground` method is
-- called with `dt`, and it is then tested against every other collider in the
-- list using the handler stored under `collider.types[c.type .. "_" ..
-- other.type]`. Collisions returned by the handlers are collected and, once
-- all pairs have been examined, resolved in the order they were found.
--
-- A pair is not tested again once a collision between the same two colliders
-- (in either order) has already been recorded in this step.
--
-- @param colliders sequence of collider instances.
-- @param dt value forwarded unchanged to each collider's `ground` method.
function collider.update(colliders, dt)
    local collisions = {}
    for i = 1, #colliders do
        local c = colliders[i]

        -- Standard Lua has no `continue` statement, so dormant colliders are
        -- skipped with a guard instead. They are still visited as `other`
        -- when an active collider iterates over the list.
        if not c.dormant then
            c:ground(dt)

            for j = 1, #colliders do
                if i ~= j then
                    local other = colliders[j]
                    -- True when a recorded collision involves exactly this
                    -- pair, regardless of which collider is listed first.
                    local function is_same_pair(o)
                        return (o.collider1 == other and o.collider2 == c)
                            or (o.collider1 == c and o.collider2 == other)
                    end
                    if not contains(collisions, is_same_pair) then
                        local collision = collider.types[c.type .. "_" .. other.type].collide(c, other)
                        if collision then
                            table.insert(collisions, collision)
                        end
                    end
                end
            end
        end
    end
    for i = 1, #collisions do
        local collision = collisions[i]
        collision:resolve()
    end
end

--- Default height lookup: returns 0 for every (x, z).
-- Can be replaced at runtime through collider.set_height.
-- @param x ignored by this default implementation.
-- @param z ignored by this default implementation.
-- @return 0
function collider.get_height(x, z)
    local default_height = 0
    return default_height
end

--- Default normal lookup: returns a fresh vector(0, 1, 0) for every (x, z).
-- Can be replaced at runtime through collider.set_normal.
-- @param x ignored by this default implementation.
-- @param z ignored by this default implementation.
-- @return a new vector with components (0, 1, 0).
function collider.get_normal(x, z)
    local default_normal = vector(0, 1, 0)
    return default_normal
end

--- Replaces collider.get_height with the supplied value.
-- No validation is performed; the argument is stored as given.
-- @param height replacement for collider.get_height (the default it replaces
--   is a function taking (x, z)).
function collider.set_height(height)
    collider["get_height"] = height
end

--- Replaces collider.get_normal with the supplied value.
-- No validation is performed; the argument is stored as given.
-- @param normal replacement for collider.get_normal (the default it replaces
--   is a function taking (x, z)).
function collider.set_normal(normal)
    collider["get_normal"] = normal
end

--- Transforms a 3x3 inverse inertia tensor in place by an orientation.
--
-- Rows 1..3 of `rotmat` are first overwritten with `q:rotate` applied to the
-- x, y and z unit vectors. `inverse_inertia` is then replaced by
-- rotmat * inverse_inertia * transpose(rotmat).
--
-- The input tensor is overwritten, so calling this repeatedly on the same
-- table compounds the transformation.
-- NOTE(review): the rows of `rotmat` are the rotated basis vectors; confirm
-- this matches the row/column convention expected by the callers.
--
-- @param q orientation object providing a `rotate(vector)` method whose
--   result can be indexed with [1], [2] and [3].
-- @param inverse_inertia 3x3 matrix indexed as [row][column]; modified in place.
-- @param rotmat optional scratch matrix to reuse; a new matrix(3, 3) is
--   created when omitted.
function collider.transformInertiaTensor(q, inverse_inertia, rotmat)
    rotmat = rotmat or matrix(3, 3)
    rotmat[1] = q:rotate(vector(1, 0, 0))
    rotmat[2] = q:rotate(vector(0, 1, 0))
    rotmat[3] = q:rotate(vector(0, 0, 1))

    -- Pass 1: product = rotmat * inverse_inertia. All nine entries are
    -- computed before anything is written back, because the second pass
    -- overwrites inverse_inertia.
    local product = {}
    for r = 1, 3 do
        local rot_row = rotmat[r]
        local out = {}
        for c = 1, 3 do
            out[c] = rot_row[1] * inverse_inertia[1][c]
                + rot_row[2] * inverse_inertia[2][c]
                + rot_row[3] * inverse_inertia[3][c]
        end
        product[r] = out
    end

    -- Pass 2: inverse_inertia = product * transpose(rotmat).
    for r = 1, 3 do
        local left = product[r]
        for c = 1, 3 do
            local rot_row = rotmat[c]
            inverse_inertia[r][c] = left[1] * rot_row[1]
                + left[2] * rot_row[2]
                + left[3] * rot_row[3]
        end
    end
end

--- Builds a 4x4 homogeneous transform from a position and an orientation.
--
-- Rows 1..3 hold the rotation matrix derived from the quaternion in their
-- first three columns and the position in the fourth column; row 4 is
-- { 0, 0, 0, 1 }. Previously the rotation rows were stored one row too low
-- (rows 2..4) behind a stray first row, which left the translation column on
-- the wrong rows.
--
-- @param position table with numeric `x`, `y`, `z` fields.
-- @param orientation quaternion with numeric `x`, `y`, `z`, `w` fields;
--   assumed to be normalised, as the formula used is the unit-quaternion one.
-- @return a matrix(4, 4) whose rows are replaced by the transform rows.
function collider.calculateTransformMatrix(position, orientation)
    local x, y, z, w = orientation.x, orientation.y, orientation.z, orientation.w

    -- Named `transform` so the `matrix` module is not shadowed.
    local transform = matrix(4, 4)
    transform[1] = { 1 - 2 * y^2 - 2 * z^2, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w, position.x }
    transform[2] = { 2 * x * y + 2 * z * w, 1 - 2 * x^2 - 2 * z^2, 2 * y * z - 2 * x * w, position.y }
    transform[3] = { 2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x^2 - 2 * y^2, position.z }
    transform[4] = { 0, 0, 0, 1 }

    return transform
end

-- Load the built-in collider modules. Each module returns a function, which
-- is called immediately with the collider table (presumably to register its
-- collider type(s) via collider.register; confirm in the individual modules).
-- NOTE(review): the load order is kept as is; cube_sphere may rely on the
-- sphere and cube modules having run first; confirm before reordering.
require("physics.collider.sphere")(collider)
require("physics.collider.cube")(collider)
require("physics.collider.cube_sphere")(collider)

return collider