local vector = require("util.geometry.vector")
local quaternion = require("util.geometry.quaternion")

-- Registers the "sphere" collider shape with the given collider registry.
-- `collider` must provide register(name, shape), get_height(x, z) and
-- get_normal(x, z).
-- NOTE(review): the terrain normal returned by get_normal is assumed to be
-- unit length — confirm.
return function(collider)
    -- Below this magnitude a rotation axis / torque is treated as zero, so we
    -- never normalize a (near-)zero vector into NaNs.
    local SPIN_EPSILON = 1e-9

    -- Rolling-without-slipping spin for a body of `radius` moving with
    -- `velocity` over a surface with unit `normal`.
    -- Returns `axis, angle` (angle = tangential speed / radius), or nil when
    -- there is no tangential motion and therefore no defined spin axis.
    local function rolling_spin(normal, velocity, radius)
        -- n x v is perpendicular to both the surface normal and the direction
        -- of travel (the rolling axis); for a unit normal its length is the
        -- tangential speed.
        local axis = normal:cross(velocity)
        local tangential_speed = axis:length()
        if tangential_speed <= SPIN_EPSILON then
            return nil
        end
        return axis:normalize(), tangential_speed / radius
    end

    collider.register("sphere", {
        -- Half-extent of the (possibly non-uniformly scaled) sphere along the
        -- world axis that dominates the direction from its centre to `facing`.
        radius = function(self, facing)
            -- Only the relative magnitudes of the components matter, so the
            -- offset is not normalized; this also avoids dividing by zero
            -- when `facing` coincides with the centre.
            local offset = facing - self.position
            local ax, ay, az = math.abs(offset.x), math.abs(offset.y), math.abs(offset.z)
            if ax >= ay and ax >= az then
                return self.scale.x / 2
            end
            if ay >= az then
                return self.scale.y / 2
            end
            return self.scale.z / 2
        end,

        -- Distance from `position` to this sphere's surface (negative inside).
        distance = function(self, position)
            return self.position:distance(position) - self:radius(position)
        end,

        -- Tests this sphere against `sphere`. Returns nil when they do not
        -- overlap; otherwise wakes `sphere` (clears `dormant`) and returns a
        -- collision record whose `normal` points from `sphere` toward `self`.
        collide = function(self, sphere)
            -- points from `sphere` toward `self`
            local direction = self.position - sphere.position
            local distance = direction:length()
            local overlap = self:radius(sphere.position) + sphere:radius(self.position) - distance
            if overlap < 0 then
                return
            end
            -- NOTE(review): if both centres coincide, `direction` is the zero
            -- vector and the normal below is undefined — confirm how
            -- vector:normalize() handles that case.
            sphere.dormant = false
            return {
                -- midpoint between the two centres (`direction` runs from
                -- `sphere` to `self`, so step back along it)
                point = self.position - direction * 0.5,
                normal = direction:normalize(),
                penetration = overlap,
                collider1 = self,
                collider2 = sphere,
                resolve = self.resolve,
            }
        end,

        -- Resolves a collision record produced by `collide`: separates the
        -- bodies, applies a restitution impulse and an angular kick.
        resolve = function(collision)
            local a, b = collision.collider1, collision.collider2
            local normal = collision.normal

            local inv_mass_sum = a.inverse_mass + b.inverse_mass
            -- two immovable bodies: nothing can move (and the impulse below
            -- would divide by zero)
            if inv_mass_sum <= 0 then
                return
            end

            -- move the colliders apart in proportion to their inverse mass,
            -- so an immovable body (inverse_mass == 0) stays put; equal
            -- masses each move half the penetration
            local correction = normal * (collision.penetration / inv_mass_sum)
            a.position += correction * a.inverse_mass
            b.position -= correction * b.inverse_mass

            -- calculate the impulse
            local relative_velocity = a.velocity - b.velocity
            local normal_velocity = relative_velocity:dot(normal)
            -- already separating: an impulse would pull them back together
            if normal_velocity > 0 then
                return
            end

            local e = math.min(a.elasticity, b.elasticity)
            local j = -(1 + e) * normal_velocity / inv_mass_sum
            local impulse = normal * j

            a.velocity += impulse * a.inverse_mass
            b.velocity -= impulse * b.inverse_mass

            -- angular impulse: the spin axis is the torque direction r x J,
            -- not the lever arm r; skip when the torque vanishes (e.g. the
            -- contact lies on the line of centres)
            local torque1 = (collision.point - a.position):cross(impulse)
            local torque2 = (collision.point - b.position):cross(impulse)

            local torque1_len = torque1:length()
            if torque1_len > SPIN_EPSILON then
                a.rotation += quaternion{
                    axis = torque1:normalize(),
                    angle = torque1_len
                }
            end

            local torque2_len = torque2:length()
            if torque2_len > SPIN_EPSILON then
                b.rotation -= quaternion{
                    axis = torque2:normalize(),
                    angle = torque2_len
                }
            end
        end,

        -- Terrain interaction: bounces off when penetrating, sticks and rolls
        -- when resting within tolerance, otherwise marks the body airborne.
        -- NOTE(review): `dt` is currently unused.
        ground = function(self, dt)
            local radius_y = self.scale.y / 2
            -- signed clearance between the bottom of the collider and the terrain
            local ground_height = self.position.y - collider.get_height(self.position.x, self.position.z) - radius_y

            -- collider is intersecting the ground
            if ground_height < -0.01 then
                -- snap up to the ground
                self.position.y += -ground_height
                self.grounded = true

                -- bounce with damping, but only if moving into the ground;
                -- reflecting an outgoing velocity would drive it back under
                local normal = collider.get_normal(self.position.x, self.position.z)
                local into_ground = self.velocity:dot(normal)
                if into_ground < 0 then
                    self.velocity = (self.velocity - normal * (2 * into_ground)) * self.elasticity
                end

                -- spin from the tangential motion
                local axis, angle = rolling_spin(normal, self.velocity, radius_y)
                if axis then
                    self.rotation += quaternion{
                        axis = axis,
                        angle = angle
                    }
                end
                return
            end

            -- collider is on the ground
            if math.eq(ground_height, 0, 0.01) then
                -- snap to the ground
                self.position.y += -ground_height
                self.grounded = true

                -- normal force: drop the velocity component along the normal
                local normal = collider.get_normal(self.position.x, self.position.z)
                self.velocity -= normal * self.velocity:dot(normal)

                -- rolling spin; a body at rest gets a zero-angle rotation
                -- instead of normalizing a zero velocity
                -- NOTE(review): this branch overwrites `rotation` while the
                -- bounce branch accumulates into it — confirm intent.
                local axis, angle = rolling_spin(normal, self.velocity, radius_y)
                if axis then
                    self.rotation = quaternion{
                        axis = axis,
                        angle = angle
                    }
                else
                    self.rotation = quaternion{
                        axis = normal,
                        angle = 0
                    }
                end
                return
            end

            self.grounded = false
        end,

        -- Cross-sectional area presented to the direction of travel.
        -- NOTE(review): this uses the half-extent along the direction of
        -- travel; for a non-uniformly scaled body the frontal cross-section
        -- is spanned by the two perpendicular half-extents — confirm intent.
        forward_area = function(self)
            local radius = self:radius(self.position + self.velocity)
            return math.pi * radius * radius
        end,

        -- Constant drag coefficient for this shape.
        drag_coeff = function(self)
            return 0.42
        end,

        -- Drag factor: forward area times drag coefficient.
        drag = function(self)
            -- called with ':' so overrides that read `self` receive it
            return self:forward_area() * self:drag_coeff()
        end,

        friction = 0.2,
        elasticity = 0.9,
    })
end