from fastapi import Depends, Form
from typing import Annotated
from mistralai import Mistral, ModerationResponse
from rvpc.auth.utils import require_env

# Mistral API key, read once at import time (module level) from the
# MISTRAL_API_KEY environment variable.
# NOTE(review): presumably require_env raises when the variable is unset, so a
# missing key fails at import rather than per request — confirm in
# rvpc.auth.utils.
api_key = require_env("MISTRAL_API_KEY")


def get_client():
    """Provide a Mistral client as a FastAPI generator dependency.

    A fresh client is opened with the module-level ``api_key`` each time the
    dependency is resolved and is yielded from inside its ``with`` block.

    Yields:
        Mistral: the open client; its context is exited when FastAPI runs
        the generator's teardown.
    """
    with Mistral(api_key=api_key) as mistral_client:
        # Yielding inside the ``with`` keeps the client open while the
        # dependent code runs and closes it during dependency teardown.
        yield mistral_client


# Dependency alias: a parameter annotated ``Moderation`` receives the Mistral
# client yielded by ``get_client``.
Moderation = Annotated[Mistral, Depends(get_client)]


def build_inputs(sender: str, recipient: str, msg: str):
    """Build the chat-message list sent to the moderation classifier.

    Each of ``sender``, ``recipient`` and ``msg`` becomes its own message
    with role ``"user"``, in exactly that order.

    NOTE(review): if Mistral's chat-moderation endpoint scores only the final
    message and uses the earlier ones as context, ``sender`` and ``recipient``
    are not classified on their own — confirm against the Mistral moderation
    docs.

    Returns:
        A list of three ``{"role": "user", "content": <field>}`` dicts.
    """
    fields = (sender, recipient, msg)
    return [dict(role="user", content=text) for text in fields]


def get_evaluation(
    moderation: Mistral,
    messages: list[dict[str, str]],
    *,
    model: str = "mistral-moderation-latest",
) -> ModerationResponse:
    """Run Mistral's chat-moderation classifier over a list of messages.

    Args:
        moderation: An open Mistral client.
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts),
            forwarded unchanged as the ``inputs`` of ``moderate_chat``.
        model: Moderation model to use. Keyword-only; the default is the
            model name this function previously hard-coded, so existing
            callers are unaffected.

    Returns:
        The response of ``moderation.classifiers.moderate_chat``.

    Raises:
        Whatever the SDK call raises; errors are not caught here.
    """
    return moderation.classifiers.moderate_chat(model=model, inputs=messages)


def check_evaluation(mod_res: ModerationResponse) -> bool:
    """Decide whether a moderation response clears the content.

    Every result in the response is inspected, not just the first, so a
    batched request cannot slip a flagged item past the check.

    Args:
        mod_res: Response returned by the Mistral moderation classifier.

    Returns:
        True when the response has at least one result and no category in
        any result is flagged. False otherwise. A response with no results,
        or a result without category data, also returns False (fail closed),
        because the content cannot be confirmed safe.
    """
    results = mod_res.results
    if not results:
        # None or an empty list: nothing was classified, so nothing is cleared.
        return False
    for result in results:
        categories = result.categories
        # Missing category data is treated like a flag rather than crashing.
        if categories is None or any(categories.values()):
            return False
    return True


def moderate(
    sender: Annotated[str, Form()],
    recipient: Annotated[str, Form()],
    msg: Annotated[str, Form()],
    moderation: Moderation,
):
    """FastAPI dependency that moderates a submitted message form.

    Reads the ``sender``, ``recipient`` and ``msg`` form fields and sends
    them to the Mistral moderation classifier as one conversation, using the
    injected client. It then returns the verdict from ``check_evaluation``.

    Returns:
        bool: True when the submission passes moderation, False when it is
        flagged.
    """
    response = get_evaluation(
        moderation=moderation,
        messages=build_inputs(sender, recipient, msg),
    )
    return check_evaluation(mod_res=response)


# Dependency alias: a parameter annotated ``Moderated`` receives the bool
# returned by ``moderate`` (True = passed moderation). ``moderate`` declares
# ``sender``, ``recipient`` and ``msg`` as Form fields, so any endpoint using
# this alias expects those form fields in the request.
Moderated = Annotated[bool, Depends(moderate)]
