local class = import("class", "local")
local events = import("events", "local")
local gui = import("gui", "local")
local color = import("color", "local")

-- Raise gui.layer_count to at least 2; a larger existing value is kept.
-- NOTE(review): presumably this module needs a second layer for its own
-- drawing -- confirm against the gui module.
gui.layer_count = math.max(gui.layer_count, 2)

--- A single application state. While a state is entered, its `on_<event>`
--- hooks are invoked for every event named in `events.list`; hooks that a
--- state does not override fall back to class-level no-ops. `on_enter`,
--- `on_leave` and `on_prepare_leave` are called by the state's own
--- `enter`, `leave` and `prepare_leave` methods when they are defined.
---@class State : Class
---@field on_enter fun(self: State, args: any)?
---@field on_leave fun(self: State, args: any)?
---@field on_prepare_leave fun(self: State)?
---@field on_click fun(self:State, pos : Vector)?
---@field on_key_down fun(self:State, key : number)?
---@field on_key_up fun(self:State, key : number)?
---@field on_wnd_proc fun(self:State, hwnd : number, msg : number, wParam : number, lParam : number)?
---@field on_resize fun(self:State, hwnd : number, x : number, y : number, w : number, h : number)?
---@field on_initialize fun(self:State)?
---@field on_update fun(self:State, dT : number)?
---@field on_render fun(self:State, layer : number)?
---@field on_activated fun(self:State)?
---@field on_deactivated fun(self:State)?
---@field on_suspending fun(self:State)?
---@field on_resuming fun(self:State)?
---@field on_display_change fun(self:State)?
---@field on_window_size_changed fun(self:State, w : number, h : number)?
---@field on_window_moved fun(self:State, x : number, y : number)?
---@field on_screen_mode_changed fun(self:State, fullscreen : boolean)?
---@field on_device_lost fun(self:State)?
---@field on_device_restored fun(self:State)?
---@field on_shutdown fun(self:State)?
local State = class("State")

--- Builds one bound dispatcher per entry of `events.list` and stores it as
--- `self.__on_<event>`. A dispatcher resolves `self.on_<event>` at call time,
--- so hooks assigned after construction are still honoured; the hook's return
--- value is discarded.
function State:__init()
    for _, event_name in pairs(events.list) do
        local hook_key = "on_" .. event_name
        self["__" .. hook_key] = function(...)
            self[hook_key](self, ...)
        end
    end
end
-- Install a do-nothing default for every event hook on the class, so the
-- dispatchers created in State:__init always find something callable.
for _, event_name in pairs(events.list) do
    local hook_key = "on_" .. event_name
    State[hook_key] = function() end
end
--- Activates the state: runs the `on_enter` hook when one is defined, then
--- subscribes each `__on_<event>` dispatcher to its event.
---@param args any forwarded unchanged to `on_enter`
function State:enter(args)
    local enter_hook = self.on_enter
    if enter_hook then
        enter_hook(self, args)
    end

    -- Subscription happens after on_enter, so the hook runs before this
    -- state receives any events.
    for _, event_name in pairs(events.list) do
        events[event_name] += self["__on_" .. event_name]
    end
end
--- Deactivates the state: runs the `on_leave` hook when one is defined, then
--- unsubscribes each `__on_<event>` dispatcher from its event.
---@param args any forwarded unchanged to `on_leave`
function State:leave(args)
    local leave_hook = self.on_leave
    if leave_hook then
        leave_hook(self, args)
    end

    for _, event_name in pairs(events.list) do
        events[event_name] -= self["__on_" .. event_name]
    end
end
--- Called before a transition away from this state begins: runs the
--- `on_prepare_leave` hook when one is defined and unsubscribes only the
--- update dispatcher; all other event subscriptions are left in place.
function State:prepare_leave()
    local prepare_hook = self.on_prepare_leave
    if prepare_hook then
        prepare_hook(self)
    end

    events.update -= self.__on_update
end

--- A timed transition. It invokes its callback at most once (the callback is
--- cleared after the first call) and, for the "fade" effect, draws a
--- full-size rectangle filled with `color(0, 0, 0, a)` while it is running.
---@class StateTransition : Class
local StateTransition = class("StateTransition")

--- Creates a transition and starts it immediately.
--- When `effect` is missing or "instant", or `duration` is missing or 0, the
--- callback is fired right away and no timer is started.
---@param effect string? effect name; "fade" is the only one with rendering
---@param duration number? total length of the transition, in the same unit as the `dT` passed to update
---@param callback function? invoked with no arguments when the transition fires
function StateTransition:__init(effect, duration, callback)
    self.callback = callback

    local is_instant = not effect or effect == "instant"
    local has_no_duration = not duration or duration == 0
    if is_instant or has_no_duration then
        return self:transition()
    end

    self.effect = effect
    self.duration = duration
    self:start()
end

--- Fires the stored callback, if any, under `pcall` and then clears it so
--- later calls become no-ops. A failing callback is logged, not re-raised.
function StateTransition:transition()
    local callback = self.callback
    if not callback then
        return
    end

    local ok, err = pcall(callback)
    if not ok then
        printf("!> StateTransition Error: %s", err)
    end
    self.callback = nil
end

--- Advances the transition timer and fires the callback at the right moment.
---
--- For the "fade" effect the callback fires once the timer passes half of the
--- duration; for any other effect it fires when the duration has elapsed.
--- In both cases the transition unsubscribes itself via `finish()` once the
--- full duration has elapsed.
---@param dT number time elapsed since the previous update
function StateTransition:update(dT)
    self.timer += dT

    -- Fire at the midpoint of a fade, but do NOT return: returning here on
    -- every tick past the midpoint would make the completion check below
    -- unreachable, so the update/render handlers would never be removed.
    -- transition() clears the callback, so repeated calls are no-ops.
    if self.effect == "fade" and self.timer > self.duration * .5 then
        self:transition()
    end

    if self.timer >= self.duration then
        self:finish()
        return self:transition()
    end
end

--- Draws the transition overlay for the "fade" effect.
--- Nothing is drawn unless `layer` equals `math.huge`.
--- NOTE(review): presumably `math.huge` identifies a final overlay pass --
--- confirm against the gui/events modules.
---@param layer number layer passed along with the render event
function StateTransition:render(layer)
    if layer ~= math.huge then
        return
    end

    local progress = self.timer / self.duration
    if self.effect ~= "fade" then
        return
    end

    -- The fourth color component ramps from 0 to 1 over the first half of
    -- the duration and back down over the second half.
    local alpha
    if self.timer < self.duration * .5 then
        alpha = progress * 2
    else
        alpha = (1 - progress) * 2
    end

    gui.no_stroke()
    gui.fill(color(0, 0, 0, alpha))
    gui.rectangle(0, 0, gui.width, gui.height)
end

--- Resets the timer and subscribes bound update/render handlers to the
--- corresponding events. The handlers are kept on `self` so that `finish()`
--- can unsubscribe the very same function values.
function StateTransition:start()
    self.timer = 0

    local function tick(dT)
        self:update(dT)
    end
    local function draw(layer)
        self:render(layer)
    end

    self.on_update = tick
    events.update += tick
    self.on_render = draw
    events.render += draw
end

--- Unsubscribes the handlers stored in `on_update` and `on_render` from the
--- update and render events.
function StateTransition:finish()
    local update_handler = self.on_update
    events.update -= update_handler

    local render_handler = self.on_render
    events.render -= render_handler
end

--- Registry of named states plus the currently active one. Switches between
--- states either immediately (`enter`) or through a StateTransition
--- (`transition`).
---@class StateManager : Class
local StateManager = class("StateManager")

-- Registered states keyed by name; declared on the class itself.
StateManager.states = {}
-- The state most recently entered through StateManager:enter, or nil.
StateManager.current = nil

--- Creates a State and registers it under `name`, replacing any state that
--- was previously stored under the same key.
---@param name any registry key (presumably a string; `enter` concatenates it into a log message)
---@return State
function StateManager:new(name)
    local created = State(name)
    local registry = self.states
    registry[name] = created
    return created
end

--- Reports whether a state is registered under `name`.
---@return boolean
function StateManager:has(name)
    local registered = self.states[name]
    return registered ~= nil
end

--- Looks up the state registered under `name`.
---@return State? the registered state, or nil when there is none
function StateManager:get(name)
    local registry = self.states
    return registry[name]
end

--- Drops the registry entry for `name`. Only the registry is touched;
--- `self.current` is not modified.
function StateManager:remove(name)
    local registry = self.states
    registry[name] = nil
end

--- Switches immediately to the state registered under `name`.
---
--- The current state (if any) is left first, then the target state is entered
--- under `pcall`; an error raised while entering is logged, not propagated.
--- Finally the GUI loop is started if it is not already running.
---
--- If no state is registered under `name`, an error is logged and nothing
--- else happens: the current state stays active and untouched.
---@param name any key the target state was registered under
---@param args any forwarded to the leaving state's `leave` and the target state's `enter`
function StateManager:enter(name, args)
    -- Resolve the target before touching anything. Otherwise an unknown name
    -- would leave the current state, clear `self.current`, and then raise an
    -- unprotected "index nil" error while building the pcall arguments.
    local target = self.states[name]
    if not target then
        printf("!> StateManager Error: unknown state '%s'", tostring(name))
        return
    end

    if self.current then
        self.current:leave(args)
    end

    print("?> Entering state: " .. name)
    self.current = target
    local ok, err = pcall(target.enter, target, args)
    if not ok then
        printf("!> StateManager Error: %s", err)
    end

    if not gui.running then
        gui.run()
    end
end

--- Switches to the state registered under `name` through a StateTransition.
--- The current state (if any) is told to prepare for leaving right away; the
--- actual `enter` call happens when the transition fires its callback.
---@param name any key the target state was registered under
---@param effect string? transition effect, passed to StateTransition
---@param duration number? transition length, passed to StateTransition
---@param args any forwarded to `enter`
function StateManager:transition(name, effect, duration, args)
    local active = self.current
    if active then
        active:prepare_leave()
    end

    local function switch_state()
        self:enter(name, args)
    end
    StateTransition(effect, duration, switch_state)
end

-- Only StateManager is returned; State and StateTransition stay local to
-- this file.
return StateManager