local vector = require("geometry.vector")
local quaternion = require("geometry.quaternion")

--- Dense matrix stored row by row: entry (i, j) lives at self[i][j], 1-based.
---@class Matrix : Class
---@field rows number number of rows
---@field columns number number of columns
local Matrix = {}

-- Instances are created with Matrix as their metatable, so methods that are
-- missing on an instance are looked up on this table.
Matrix["__index"] = Matrix

--- Element-wise sum of two matrices of identical shape.
---@param m Matrix right operand; must have the same `rows` and `columns` as self
---@return Matrix new matrix holding self[i][j] + m[i][j]
function Matrix:__add(m)
    -- Reject mismatched operands up front. Otherwise a smaller `m` fails on a
    -- nil row mid-loop and a larger one is silently truncated to self's shape.
    -- The type checks also cover `number + Matrix`, which arrives here with
    -- the number as self.
    if type(self) ~= "table" or type(m) ~= "table"
        or self.rows ~= m.rows or self.columns ~= m.columns then
        error("Matrix addition requires two matrices of the same shape", 2)
    end
    local result = Matrix(self.rows, self.columns)
    for i = 1, self.rows do
        local a, b, out = self[i], m[i], result[i]
        for j = 1, self.columns do
            out[j] = a[j] + b[j]
        end
    end
    return result
end

--- Element-wise difference of two matrices of identical shape.
---@param m Matrix right operand; must have the same `rows` and `columns` as self
---@return Matrix new matrix holding self[i][j] - m[i][j]
function Matrix:__sub(m)
    -- Reject mismatched operands up front. Otherwise a smaller `m` fails on a
    -- nil row mid-loop and a larger one is silently truncated to self's shape.
    -- The type checks also cover `number - Matrix`, which arrives here with
    -- the number as self.
    if type(self) ~= "table" or type(m) ~= "table"
        or self.rows ~= m.rows or self.columns ~= m.columns then
        error("Matrix subtraction requires two matrices of the same shape", 2)
    end
    local result = Matrix(self.rows, self.columns)
    for i = 1, self.rows do
        local a, b, out = self[i], m[i], result[i]
        for j = 1, self.columns do
            out[j] = a[j] - b[j]
        end
    end
    return result
end

--- Multiplication operator.
--- Supported operands:
---   * Matrix * number and number * Matrix -> new Matrix with every entry scaled
---   * Matrix * Matrix -> matrix product; requires self.columns == m.rows
---   * Matrix * vector -> vector; the matrix (at most 3x3) is applied to the
---     vector's components in x, y, z order
---   * Matrix * quaternion -> quaternion; the matrix (at most 4x4) is applied to
---     the quaternion's components in w, x, y, z order
--- Any other operand raises an error.
---@param m number|Matrix|table
---@return Matrix|table
function Matrix:__mul(m)
    -- Lua passes operands in source order, so `2 * M` arrives here as
    -- (2, M). Swap them so the scalar always ends up in `m`.
    if type(self) == "number" then
        self, m = m, self
    end

    if type(m) == "number" then
        local result = Matrix(self.rows, self.columns)
        for i = 1, self.rows do
            local src, dst = self[i], result[i]
            for j = 1, self.columns do
                dst[j] = src[j] * m
            end
        end
        return result
    end

    local mt = getmetatable(m)
    if mt == Matrix then
        if self.columns ~= m.rows then
            error("Matrix product requires self.columns == m.rows", 2)
        end
        local result = Matrix(self.rows, m.columns)
        for i = 1, self.rows do
            local row = self[i]
            for j = 1, m.columns do
                local sum = 0
                for k = 1, self.columns do
                    sum = sum + row[k] * m[k][j]
                end
                result[i][j] = sum
            end
        end
        return result
    elseif mt == vector then
        if self.rows > 3 or self.columns > 3 then
            error("Matrix * vector requires a matrix of at most 3x3", 2)
        end
        local names = { "x", "y", "z" }
        local sums = { 0, 0, 0 }
        for i = 1, self.rows do
            local sum = 0
            for j = 1, self.columns do
                -- This module reads vectors through named fields elsewhere
                -- (v.x, v.y, v.z); numeric indexing is kept as a fallback.
                local component = m[names[j]]
                if component == nil then
                    component = m[j]
                end
                sum = sum + self[i][j] * component
            end
            sums[i] = sum
        end
        -- Build the result through the constructor (x, y, z order) rather than
        -- assigning numeric keys on a fresh vector, which need not update x/y/z.
        return vector(sums[1], sums[2], sums[3])
    elseif mt == quaternion then
        if self.rows > 4 or self.columns > 4 then
            error("Matrix * quaternion requires a matrix of at most 4x4", 2)
        end
        -- Component order used both for reading `m` and for writing the result.
        local names = { "w", "x", "y", "z" }
        local result = quaternion(0, 0, 0, 0)
        for i = 1, self.rows do
            local sum = 0
            for j = 1, self.columns do
                sum = sum + self[i][j] * m[names[j]]
            end
            -- Row i of the matrix yields output component i (an ordinary
            -- matrix-vector product over the four components).
            -- NOTE(review): assumes quaternion components are assignable fields.
            result[names[i]] = sum
        end
        return result
    end

    error("unsupported operand for Matrix multiplication: " .. type(m), 2)
end

--- Divides every entry by a scalar.
---@param m number divisor
---@return Matrix new matrix holding self[i][j] / m
function Matrix:__div(m)
    local quotient = Matrix(self.rows, self.columns)
    for r = 1, self.rows do
        local src, dst = self[r], quotient[r]
        for c = 1, self.columns do
            dst[c] = src[c] / m
        end
    end
    return quotient
end

--- Unary minus: returns a new matrix with every entry negated.
---@return Matrix
function Matrix:__unm()
    local negated = Matrix(self.rows, self.columns)
    for r = 1, self.rows do
        local src, dst = self[r], negated[r]
        for c = 1, self.columns do
            dst[c] = -src[c]
        end
    end
    return negated
end

--- Length operator (`#M`, honoured for tables from Lua 5.2): the Euclidean
--- norm of the first (up to) three main-diagonal entries,
--- sqrt(M[1][1]^2 + M[2][2]^2 + M[3][3]^2).
--- Matrices smaller than 3x3 use the diagonal entries they have instead of
--- failing on a missing row. Diagonal entries beyond [3][3] are ignored.
--- NOTE(review): this is not a standard matrix norm; confirm callers want the
--- diagonal rather than, e.g., the Frobenius norm.
---@return number
function Matrix:__len()
    local n = math.min(self.rows, self.columns, 3)
    local sum = 0
    for i = 1, n do
        local d = self[i][i]
        sum = sum + d * d
    end
    return math.sqrt(sum)
end

--- Structural equality. The result is true when both matrices have the same
--- shape and every pair of corresponding entries compares equal with `==`.
---@param m Matrix
---@return boolean
function Matrix:__eq(m)
    local same_shape = self.rows == m.rows and self.columns == m.columns
    if not same_shape then
        return false
    end
    for r = 1, self.rows do
        local lhs, rhs = self[r], m[r]
        for c = 1, self.columns do
            if lhs[c] ~= rhs[c] then
                return false
            end
        end
    end
    return true
end

--- Renders the matrix as nested brackets, formatting each entry with "%f",
--- e.g. "[[1.000000, 2.000000], [3.000000, 4.000000]]".
--- A matrix with no rows prints "[]". A row with no columns prints as "[]"
--- instead of raising a format error.
---@return string
function Matrix:__tostring()
    -- Collect pieces and join once; repeated `..` in a loop is quadratic.
    local rendered_rows = {}
    for i = 1, self.rows do
        local row = self[i]
        local cells = {}
        for j = 1, self.columns do
            cells[j] = string.format("%f", row[j])
        end
        rendered_rows[i] = "[" .. table.concat(cells, ", ") .. "]"
    end
    return "[" .. table.concat(rendered_rows, ", ") .. "]"
end

--- Returns a new columns x rows matrix whose entry (j, i) is self[i][j].
---@return Matrix
function Matrix:transpose()
    local flipped = Matrix(self.columns, self.rows)
    for c = 1, self.columns do
        local dst = flipped[c]
        for r = 1, self.rows do
            dst[r] = self[r][c]
        end
    end
    return flipped
end

-- Inverts an n x n matrix by Gauss-Jordan elimination with partial pivoting.
-- Returns nil when a column has no non-zero pivot (the matrix is singular).
local function gauss_jordan_inverse(source, n)
    -- Work on a copy so the receiver is never mutated; `inv` starts as identity.
    local a = {}
    local inv = Matrix(n, n)
    for i = 1, n do
        local row = {}
        for j = 1, n do
            row[j] = source[i][j]
        end
        a[i] = row
        inv[i][i] = 1
    end
    for col = 1, n do
        -- Partial pivoting: use the row with the largest magnitude in this column.
        local pivot_row, best = col, math.abs(a[col][col])
        for r = col + 1, n do
            local magnitude = math.abs(a[r][col])
            if magnitude > best then
                pivot_row, best = r, magnitude
            end
        end
        if best == 0 then
            return nil
        end
        a[col], a[pivot_row] = a[pivot_row], a[col]
        inv[col], inv[pivot_row] = inv[pivot_row], inv[col]
        local arow, irow = a[col], inv[col]
        local pivot = arow[col]
        for j = 1, n do
            arow[j] = arow[j] / pivot
            irow[j] = irow[j] / pivot
        end
        for r = 1, n do
            local factor = a[r][col]
            if r ~= col and factor ~= 0 then
                local ar, ir = a[r], inv[r]
                for j = 1, n do
                    ar[j] = ar[j] - factor * arow[j]
                    ir[j] = ir[j] - factor * irow[j]
                end
            end
        end
    end
    return inv
end

--- Returns the inverse of a square matrix, or nil when it is singular, not
--- square, or empty.
--- 2x2 and 3x3 matrices use the closed-form adjugate / determinant formulas.
--- Other sizes use Gauss-Jordan elimination with partial pivoting.
--- Singularity is detected by an exact zero determinant / pivot, so nearly
--- singular matrices yield very large entries rather than nil.
---@return Matrix|nil
function Matrix:inverse()
    local n = self.rows
    if n < 1 or n ~= self.columns then
        return nil
    end
    if n == 2 then
        local a, b = self[1][1], self[1][2]
        local c, d = self[2][1], self[2][2]
        local det = a * d - b * c
        if det == 0 then
            return nil
        end
        -- Fill the adjugate entry by entry using the numeric
        -- Matrix(rows, columns) constructor form.
        local adj = Matrix(2, 2)
        adj[1][1], adj[1][2] = d, -b
        adj[2][1], adj[2][2] = -c, a
        return adj / det
    elseif n == 3 then
        local a11, a12, a13 = self[1][1], self[1][2], self[1][3]
        local a21, a22, a23 = self[2][1], self[2][2], self[2][3]
        local a31, a32, a33 = self[3][1], self[3][2], self[3][3]
        local det = a11 * (a22 * a33 - a23 * a32)
            - a12 * (a21 * a33 - a23 * a31)
            + a13 * (a21 * a32 - a22 * a31)
        if det == 0 then
            return nil
        end
        -- Adjugate (transposed cofactor matrix), filled entry by entry.
        local adj = Matrix(3, 3)
        adj[1][1] = a22 * a33 - a23 * a32
        adj[1][2] = a13 * a32 - a12 * a33
        adj[1][3] = a12 * a23 - a13 * a22
        adj[2][1] = a23 * a31 - a21 * a33
        adj[2][2] = a11 * a33 - a13 * a31
        adj[2][3] = a13 * a21 - a11 * a23
        adj[3][1] = a21 * a32 - a22 * a31
        adj[3][2] = a12 * a31 - a11 * a32
        adj[3][3] = a11 * a22 - a12 * a21
        return adj / det
    end
    return gauss_jordan_inverse(self, n)
end

--- Creates a matrix with ones on the main diagonal and zeros elsewhere.
---@param rows number
---@param columns? number defaults to `rows`, giving a square identity matrix
---@return Matrix
function Matrix.identity(rows, columns)
    columns = columns or rows
    local result = Matrix(rows, columns)
    for i = 1, math.min(rows, columns) do
        result[i][i] = 1
    end
    return result
end

--- Returns a new matrix of the same shape holding the entries
--- self[1..rows][1..columns]. Row tables are not shared with the original.
---@return Matrix
function Matrix:copy()
    local clone = Matrix(self.rows, self.columns)
    for r = 1, self.rows do
        local src, dst = self[r], clone[r]
        for c = 1, self.columns do
            dst[c] = src[c]
        end
    end
    return clone
end

--- Converts a 3x1 column or a 1x3 row matrix into a vector built from its
--- three entries in order. Any other shape returns nothing.
---@return table|nil
function Matrix:to_vector()
    local is_column = self.rows == 3 and self.columns == 1
    if is_column then
        return vector(self[1][1], self[2][1], self[3][1])
    end
    local is_row = self.rows == 1 and self.columns == 3
    if is_row then
        local row = self[1]
        return vector(row[1], row[2], row[3])
    end
end

--- Converts a 4x1 column matrix into a quaternion, passing the four entries
--- to the quaternion constructor top to bottom. Any other shape returns nothing.
--- NOTE(review): which component each argument becomes depends on the
--- quaternion constructor's parameter order; confirm it matches the layout
--- used when quaternions are written into a matrix.
---@return table|nil
function Matrix:to_quaternion()
    if self.rows ~= 4 or self.columns ~= 1 then
        return
    end
    return quaternion(self[1][1], self[2][1], self[3][1], self[4][1])
end

--- Appends three values, read as v[1], v[2], v[3], to a single-column or
--- single-row matrix. Any other shape is left unchanged.
--- Single column: `rows` is incremented and a new row {v[1], v[2], v[3]} is
--- stored at the new last index.
--- Single row: `columns` is incremented and v[1], v[2], v[3] are written into
--- rows 1, 2, 3 of the new column.
--- NOTE(review): both branches disagree with the recorded shape:
---   * columns == 1: the new row holds three entries while `columns` stays 1,
---     so code that loops over 1..columns (arithmetic, __eq, copy, transpose)
---     only sees v[1].
---   * rows == 1: the constructor only creates row tables 1..rows, so self[2]
---     and self[3] are normally nil here. This branch then raises "attempt to
---     index a nil value" after already bumping `columns` and writing self[1].
---   Possibly `columns == 3` / `rows == 3` (append one row / one column) was
---   intended, which would also match pop() and pop_matrix(). Confirm the
---   intended semantics with callers before changing behavior.
---@param v table sequence read at indices 1..3 (push_matrix passes a matrix row)
function Matrix:push(v)
    if self.columns == 1 then
        self.rows = self.rows + 1
        self[self.rows] = {v[1], v[2], v[3]}
    elseif self.rows == 1 then
        self.columns = self.columns + 1
        -- NOTE(review): rows 2 and 3 do not exist on a 1 x N matrix (see above).
        self[1][self.columns] = v[1]
        self[2][self.columns] = v[2]
        self[3][self.columns] = v[3]
    end
end

--- Shrinks the recorded shape by one. A single-column matrix loses its last
--- row; otherwise a single-row matrix loses its last column. Other shapes are
--- unchanged. Only the `rows`/`columns` counters change: the popped entries
--- stay in the table.
--- NOTE(review): there is no guard, so popping an empty matrix drives the
--- counter negative.
function Matrix:pop()
    if self.columns == 1 then
        self.rows = self.rows - 1
        return
    end
    if self.rows == 1 then
        self.columns = self.columns - 1
    end
end

--- Appends a quaternion to a single-column matrix as one new row holding
--- {q.x, q.y, q.z, q.w} and increments `rows`. Matrices with any other
--- column count are left unchanged.
--- NOTE(review): the new row has four entries while `columns` stays 1, so
--- code that loops over 1..columns only sees q.x. The x, y, z, w order also
--- differs from the w, x, y, z component order used by the Matrix * quaternion
--- product. Confirm which layout is intended before changing behavior.
---@param q table quaternion-like value exposing x, y, z, w fields
function Matrix:push_quaternion(q)
    if self.columns == 1 then
        self.rows = self.rows + 1
        self[self.rows] = {q.x, q.y, q.z, q.w}
    end
end

--- Shrinks a single-column matrix by one row. Only the counter changes; the
--- row table is kept. Other shapes are unchanged.
function Matrix:pop_quaternion()
    if self.columns ~= 1 then
        return
    end
    self.rows = self.rows - 1
end

--- Pushes rows 1..m.rows of `m` onto this matrix, in order, via push().
--- The row count is read once, before the first push.
---@param m Matrix
function Matrix:push_matrix(m)
    local count = m.rows
    local i = 1
    while i <= count do
        self:push(m[i])
        i = i + 1
    end
end

--- Calls pop() once per row of `m`. Only `m.rows` is read, and only once,
--- up front.
---@param m Matrix
function Matrix:pop_matrix(m)
    for _ = m.rows, 1, -1 do
        self:pop()
    end
end

--- Appends a vector's x, y, z fields to a single-column or single-row matrix.
--- Any other shape is left unchanged.
--- Single column: `rows` is incremented and a new row {v.x, v.y, v.z} is
--- stored at the new last index.
--- Single row: `columns` is incremented and v.x, v.y, v.z are written into
--- rows 1, 2, 3 of the new column.
--- NOTE(review): both branches disagree with the recorded shape:
---   * columns == 1: the new row holds three entries while `columns` stays 1,
---     so code that loops over 1..columns only sees v.x.
---   * rows == 1: the constructor only creates row tables 1..rows, so self[2]
---     and self[3] are normally nil here. This branch then raises "attempt to
---     index a nil value" after already bumping `columns` and writing self[1].
---   Possibly `columns == 3` / `rows == 3` was intended, which would match
---   pop_vector() removing a single row/column. Confirm before changing.
---@param v table vector-like value exposing x, y, z fields
function Matrix:push_vector(v)
    if self.columns == 1 then
        self.rows = self.rows + 1
        self[self.rows] = {v.x, v.y, v.z}
    elseif self.rows == 1 then
        self.columns = self.columns + 1
        -- NOTE(review): rows 2 and 3 do not exist on a 1 x N matrix (see above).
        self[1][self.columns] = v.x
        self[2][self.columns] = v.y
        self[3][self.columns] = v.z
    end
end

--- Shrinks the recorded shape by one. A single-column matrix loses a row;
--- otherwise a single-row matrix loses a column. Only the counters change: the
--- popped entries stay in the table. Other shapes are unchanged.
function Matrix:pop_vector()
    local single_column = self.columns == 1
    local single_row = self.rows == 1
    if single_column then
        self.rows = self.rows - 1
    elseif single_row then
        self.columns = self.columns - 1
    end
end

setmetatable(Matrix, {
    --- Constructor.
    --- Matrix(rows, columns): a rows x columns matrix filled with zeros.
    --- Matrix({{a, b}, {c, d}}): a matrix whose entries are copied from a
    --- table of equally long rows. The source tables are not aliased.
    __call = function(_, rows, columns)
        if type(rows) == "table" then
            local source = rows
            local height = #source
            local width = height > 0 and #source[1] or 0
            local m = setmetatable({
                rows = height,
                columns = width
            }, Matrix)
            for i = 1, height do
                local src = source[i]
                if #src ~= width then
                    error("Matrix rows must all have the same length", 2)
                end
                local row = {}
                for j = 1, width do
                    row[j] = src[j]
                end
                m[i] = row
            end
            return m
        end

        local v = setmetatable({
            rows = rows,
            columns = columns
        }, Matrix)
        for i = 1, rows do
            v[i] = {}
            for j = 1, columns do
                v[i][j] = 0
            end
        end
        return v
    end
})

return Matrix