import asyncio
import logging

from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates
from fastapi.websockets import WebSocketState
from sqlmodel import select

from rvpc.celery.celery_app import count_cards
from rvpc.db import SessionDep
from rvpc.managers.manager import WSManager
from rvpc.models.models import Location
from rvpc.routers.utils import smart


def make_map(prefix: str, templates: Jinja2Templates, *, poll_interval: float = 0.1):
    """Build the router for the map page, its data endpoints and live websockets.

    Routes (relative to ``prefix``):
        ``/``           map page context with every ``Location``.
        ``/locations``  every ``Location`` as JSON.
        ``/poll``       enqueue a ``count_cards`` Celery task.
        ``/counts``     websocket relaying the ``loc_count`` queue.
        ``/locs``       websocket relaying the ``new_card`` and
                        ``capture_status`` queues.

    Args:
        prefix: URL prefix applied to every route.
        templates: Accepted for interface compatibility; not referenced in
            this function.
        poll_interval: Seconds a websocket handler waits after finding all of
            its queues empty, so idle connections do not hammer the broker.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(prefix=prefix)
    # Sockets connected to /locs.
    map_manager = WSManager(id=1, connections=set())
    # Sockets connected to /counts.
    data_manager = WSManager(id=2, connections=set())

    async def _broadcast(manager: WSManager, text: str) -> None:
        """Send ``text`` to every connected socket registered with ``manager``.

        Iterates over a snapshot: other handlers add and remove sockets while
        this coroutine is suspended in ``send_text``, and mutating a set during
        iteration raises ``RuntimeError``. A socket whose send fails is dropped
        rather than aborting delivery to the remaining clients.
        """
        for conn in list(manager.connections):
            if conn.application_state != WebSocketState.CONNECTED:
                continue
            try:
                await conn.send_text(text)
            except Exception as exc:  # one dead peer must not stop the others
                logging.info("Dropping websocket after failed send: %r", exc)
                manager.connections.discard(conn)

    async def _forward_one(queue, manager: WSManager) -> bool:
        """Ack and broadcast at most one pending message from ``queue``.

        ``queue`` is presumably an aio_pika queue (``get``/``ack``/``body``
        usage matches) — confirm against ``app.state.channel``.

        Returns:
            ``True`` if a message was forwarded, ``False`` if the queue was empty.
        """
        msg = await queue.get(fail=False)
        if msg is None:
            return False
        await msg.ack()
        await _broadcast(manager, msg.body.decode())
        return True

    async def _pump(queues: list, manager: WSManager) -> None:
        """Relay messages from ``queues`` to ``manager``'s sockets until cancelled."""
        while True:
            forwarded = [await _forward_one(q, manager) for q in queues]
            if not any(forwarded):
                # get(fail=False) returns immediately when the queue is empty;
                # without a pause this loop would poll the broker non-stop.
                await asyncio.sleep(poll_interval)

    async def _wait_for_disconnect(ws: WebSocket) -> None:
        """Return once the client closes ``ws``; inbound frames are ignored."""
        while True:
            message = await ws.receive()
            if message["type"] == "websocket.disconnect":
                return

    async def _serve(
        ws: WebSocket, manager: WSManager, queue_names: list[str]
    ) -> None:
        """Relay ``queue_names`` to ``manager`` until the accepted ``ws`` leaves.

        The relay itself never reads from the socket, so a receive-side
        watcher runs alongside it; otherwise a closed client would only be
        noticed on a failed send and the poll loop would run indefinitely.
        """
        manager.connections.add(ws)
        tasks: list[asyncio.Task] = []
        try:
            channel = ws.app.state.channel
            queues = [await channel.get_queue(name) for name in queue_names]
            tasks.append(asyncio.create_task(_wait_for_disconnect(ws)))
            tasks.append(asyncio.create_task(_pump(queues, manager)))
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                # The pump only ends by raising; propagate that failure.
                task.result()
            logging.info("Disconnected")
        except asyncio.CancelledError:
            logging.info("WebSocket task cancelled (reload/shutdown)")
            raise
        finally:
            manager.connections.discard(ws)
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    @router.get("/")
    @smart("map/full_map.html")
    async def render_map(session: SessionDep):
        """Return the ``map/full_map.html`` context: every ``Location`` and a height."""
        # NOTE(review): session.exec is a synchronous call inside a coroutine
        # and presumably blocks the event loop for the query — confirm.
        locs = session.exec(select(Location)).all()

        return {"height": "h-full", "locations": locs}

    @router.get("/locations", response_model=list[Location])
    async def return_locs(session: SessionDep):
        """Return every ``Location`` row."""
        return session.exec(select(Location)).all()

    @router.websocket("/counts")
    async def count_ws(ws: WebSocket):
        """Stream ``loc_count`` messages to every /counts client.

        One ``count_cards`` task is enqueued per connection and revoked when the
        connection ends. No DB session dependency is taken: nothing here uses
        it, and it would otherwise stay open for the socket's whole lifetime.
        """
        await ws.accept()
        count_task = count_cards.delay()
        try:
            await _serve(ws, data_manager, ["loc_count"])
        finally:
            count_task.revoke()

    @router.get("/poll")
    async def poll(request: Request):
        """Enqueue a ``count_cards`` task (results presumably land on ``loc_count``)."""
        count_cards.delay()

    @router.websocket("/locs")
    async def locs_ws(ws: WebSocket):
        """Stream ``new_card`` and ``capture_status`` messages to every /locs client."""
        await ws.accept()
        # new_card is drained before capture_status each round, matching the
        # original delivery order.
        await _serve(ws, map_manager, ["new_card", "capture_status"])

    return router
