import base64
import hashlib
import hmac
from typing import Callable

from sqlmodel import Session

from rvpc.auth.tokens.typing import (
    DecodedPayload,
    TokenValidated,
    TokenValidationError,
)
from rvpc.models.models import TokenInfo


def token_payload_validator(
    payload: str,
) -> DecodedPayload | TokenValidationError:
    """Build a ``DecodedPayload`` from a ``:``-separated payload string.

    Every ``:``-separated field is passed positionally to ``DecodedPayload``.

    Args:
        payload: The decoded (signature-checked) token payload.

    Returns:
        The constructed ``DecodedPayload``, or
        ``TokenValidationError("Malformed token")`` if the fields cannot be
        turned into one.
    """
    fields = payload.split(":")
    try:
        return DecodedPayload(*fields)
    except (TypeError, ValueError):
        # A wrong field count makes a NamedTuple/dataclass-style constructor
        # (presumably what DecodedPayload is -- confirm) raise TypeError, not
        # ValueError. Catch both so a malformed token is reported instead of
        # escaping as an unhandled exception.
        return TokenValidationError("Malformed token")


def token_type_validator(
    d_p: DecodedPayload, expected_type: str
) -> DecodedPayload | TokenValidationError:
    """Check that the decoded payload carries the expected token type.

    Args:
        d_p: The decoded token payload.
        expected_type: The token type the caller requires.

    Returns:
        *d_p* unchanged when its ``token_type`` matches *expected_type*,
        otherwise ``TokenValidationError("Wrong token type")``.
    """
    mismatched = d_p.token_type != expected_type
    return TokenValidationError("Wrong token type") if mismatched else d_p


def token_info_validator(
    token_info: TokenInfo | None,
) -> TokenInfo | TokenValidationError:
    """Check that a stored token record exists and has not expired.

    Args:
        token_info: The record fetched for the token, or ``None`` when no
            record was available.

    Returns:
        *token_info* itself when present and not expired. Otherwise a
        ``TokenValidationError``; in the expired case the error also
        carries the record via ``token_info=``.
    """
    if token_info is None:
        failure = TokenValidationError("Token already used or not found")
    elif token_info.expired:
        failure = TokenValidationError("Token expired", token_info=token_info)
    else:
        return token_info
    return failure


def make_token_sig_validator(key: str):
    """Build a validator that checks a token's HMAC-SHA256 signature.

    The returned callable expects tokens shaped as
    ``<urlsafe-base64(payload)>.<hex HMAC-SHA256(key, payload)>``. The MAC is
    computed over the *decoded* payload bytes.

    Args:
        key: Shared secret used as the HMAC key.

    Returns:
        A function ``token -> str | TokenValidationError`` that returns the
        decoded payload string when the signature is valid.
    """
    # The key is fixed for this validator, so encode it once rather than
    # on every call.
    key_bytes = key.encode()

    def token_sig_validator(
        token: str,
    ) -> str | TokenValidationError:
        """Verify *token*'s signature and return its decoded payload.

        Returns ``TokenValidationError("Malformed token")`` when the token is
        not exactly two ``.``-separated parts or the payload is not valid
        base64/UTF-8. Returns ``TokenValidationError("Invalid signature")``
        when the signature does not match.
        """
        try:
            payload_b64, signature = token.split(".")
        except ValueError:
            return TokenValidationError("Malformed token")

        try:
            payload_bytes = base64.urlsafe_b64decode(payload_b64.encode())
            payload = payload_bytes.decode()
        except ValueError:
            # binascii.Error (bad padding), UnicodeDecodeError and
            # UnicodeEncodeError are all ValueError subclasses. The token is
            # untrusted input, so report it instead of crashing the caller.
            return TokenValidationError("Malformed token")

        # After a strict UTF-8 decode, payload.encode() == payload_bytes, so
        # sign the bytes directly.
        expected_sig = hmac.new(
            key_bytes, payload_bytes, hashlib.sha256
        ).hexdigest()

        # compare_digest raises TypeError for non-ASCII str arguments. A
        # hexdigest is always ASCII, so a non-ASCII signature can never match.
        if not signature.isascii() or not hmac.compare_digest(
            expected_sig, signature
        ):
            return TokenValidationError("Invalid signature")

        return payload

    return token_sig_validator


def make_generic_validator(
    sig_validator: Callable[[str], str | TokenValidationError],
):
    """Compose signature, payload and type checks into a single validator.

    Args:
        sig_validator: Callable that verifies a raw token and returns its
            payload string, or a ``TokenValidationError``.

    Returns:
        A function ``(token, expected_type) -> DecodedPayload |
        TokenValidationError``. It runs the stages in order and returns the
        first error produced.
    """

    def generic_validator(
        token: str,
        expected_type: str,
    ) -> DecodedPayload | TokenValidationError:
        # Each stage runs only if every earlier stage succeeded; the first
        # TokenValidationError is carried through unchanged to the return.
        stage_result = sig_validator(token)
        if not isinstance(stage_result, TokenValidationError):
            stage_result = token_payload_validator(payload=stage_result)
        if not isinstance(stage_result, TokenValidationError):
            stage_result = token_type_validator(
                d_p=stage_result, expected_type=expected_type
            )
        return stage_result

    return generic_validator


def make_validate_token(
    validator: Callable[[str, str], DecodedPayload | TokenValidationError],
    info_getter: Callable[[DecodedPayload, Session], TokenInfo | None],
):
    """Build the full token validation pipeline.

    Args:
        validator: Decodes and checks a raw token for an expected type.
        info_getter: Looks up the stored record for a decoded payload using
            the given session; returns ``None`` when there is no record.

    Returns:
        A function ``(token, session, expected_type) -> TokenValidated |
        TokenValidationError``.
    """

    def validate_token(
        token: str,
        session: Session,
        expected_type: str,
    ) -> TokenValidated | TokenValidationError:
        decoded = validator(token, expected_type)
        if isinstance(decoded, TokenValidationError):
            return decoded

        checked_info = token_info_validator(
            token_info=info_getter(decoded, session)
        )
        if isinstance(checked_info, TokenValidationError):
            return checked_info

        return TokenValidated(id=decoded.id, token_info=checked_info)

    return validate_token
