"""Persistent session-store adapters for research and QA history."""

from __future__ import annotations

import json
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol

from app.config import settings
from app.infrastructure.mysql_client import MySQLClient
from app.infrastructure.redis_client import RedisClient
from app.utils.logger import get_logger

logger = get_logger(__name__)


class ResearchSessionStore(Protocol):
    """Structural interface for persisting per-user research/QA session history.

    Any object exposing these two coroutine methods satisfies the protocol;
    no inheritance is required.
    """

    async def load_session(self, session_id: str, user_id: str) -> list[dict[str, Any]]:
        """Return the recorded turns for ``session_id`` owned by ``user_id``."""

    async def save_session(
        self,
        session_id: str,
        user_id: str,
        question: str,
        answer: str,
        refs: list[dict[str, Any]],
    ) -> None:
        """Record one question/answer turn, with its references, for the session."""


class RedisResearchSessionStore:
    """Session history kept as a single JSON array per (user, session) Redis key.

    Every save rewrites the whole array with ``SETEX``, which also resets the
    key's expiry to ``ttl_seconds``.
    """

    def __init__(self, redis_client: RedisClient, *, ttl_seconds: int) -> None:
        self._redis = redis_client
        self._ttl_seconds = ttl_seconds

    async def load_session(self, session_id: str, user_id: str) -> list[dict[str, Any]]:
        """Return the decoded turns for this session, or ``[]`` when the key is missing/empty."""
        cached = await self._redis.raw.get(self._key(session_id, user_id))
        return json.loads(cached) if cached else []

    async def save_session(
        self,
        session_id: str,
        user_id: str,
        question: str,
        answer: str,
        refs: list[dict[str, Any]],
    ) -> None:
        """Append one question/answer turn and persist the array with a fresh TTL."""
        # NOTE(review): this is a read-then-write with no WATCH/MULTI, so two
        # concurrent saves to the same session can drop one of the turns.
        storage_key = self._key(session_id, user_id)
        turns = await self.load_session(session_id, user_id)
        turns.append({"question": question, "answer": answer, "references": refs})
        serialized = json.dumps(turns, ensure_ascii=False)
        await self._redis.raw.setex(storage_key, self._ttl_seconds, serialized)

    @staticmethod
    def _key(session_id: str, user_id: str) -> str:
        # Namespaced by user first so one user's sessions share a common prefix.
        return f"research:session:{user_id}:{session_id}"


class MySQLResearchSessionStore:
    """Session history stored as a JSON array in the ``research_sessions`` table.

    Reads ignore rows whose ``expires_at`` is in the past (compared against
    ``UTC_TIMESTAMP()``). Writes upsert the row and push ``expires_at`` forward
    by ``ttl_seconds``. The upsert relies on a unique key on the table,
    presumably ``(session_id, user_id)`` -- confirm against the schema.
    """

    def __init__(self, mysql_client: MySQLClient, *, ttl_seconds: int) -> None:
        self._mysql = mysql_client
        self._ttl_seconds = ttl_seconds

    async def load_session(self, session_id: str, user_id: str) -> list[dict[str, Any]]:
        """Return the stored turns for a non-expired session, or ``[]``.

        ``history`` may arrive as text, bytes, or an already-decoded object,
        depending on driver and column type, so all three are accepted.
        SQL ``NULL`` and JSON ``null`` both yield ``[]``.
        """
        row = await self._mysql.fetch_one(
            """
            SELECT history
            FROM research_sessions
            WHERE session_id = %s
              AND user_id = %s
              AND (expires_at IS NULL OR expires_at > UTC_TIMESTAMP())
            LIMIT 1
            """,
            (session_id, user_id),
        )
        if not row:
            return []

        history = row.get("history")
        if isinstance(history, (bytes, bytearray)):
            history = history.decode("utf-8")
        if isinstance(history, str):
            history = json.loads(history)
        # A stored JSON ``null`` must not escape as ``None``: save_session appends to the result.
        return list(history or [])

    async def save_session(
        self,
        session_id: str,
        user_id: str,
        question: str,
        answer: str,
        refs: list[dict[str, Any]],
    ) -> None:
        """Append one question/answer turn and upsert the row with a refreshed expiry."""
        # NOTE(review): read-modify-write outside a transaction; concurrent saves
        # to the same session can lose a turn.
        history = await self.load_session(session_id, user_id)
        history.append(
            {"question": question, "answer": answer, "references": refs}
        )
        payload = json.dumps(history, ensure_ascii=False)
        # datetime.utcnow() is deprecated (Python 3.12+). Take an aware UTC time
        # and drop tzinfo so the driver still receives a naive UTC value,
        # comparable with UTC_TIMESTAMP() in load_session.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        expires_at = now_utc + timedelta(seconds=self._ttl_seconds)
        # The "AS new_values" row alias requires MySQL 8.0.19+.
        await self._mysql.execute(
            """
            INSERT INTO research_sessions (session_id, user_id, history, expires_at)
            VALUES (%s, %s, %s, %s) AS new_values
            ON DUPLICATE KEY UPDATE
                history = new_values.history,
                expires_at = new_values.expires_at,
                updated_at = CURRENT_TIMESTAMP
            """,
            (session_id, user_id, payload, expires_at),
        )


def build_research_session_store(
    *,
    redis_client: RedisClient,
    mysql_client: MySQLClient | None,
) -> ResearchSessionStore:
    """Build the configured session store, falling back to Redis when needed.

    ``settings.research_session_backend`` is matched case-insensitively and
    ignoring surrounding whitespace. ``"mysql"`` selects the MySQL store when
    a client is available. Every other value yields the Redis store. Both
    stores share ``settings.research_session_ttl_seconds``.

    A warning is logged when MySQL was requested but no client was supplied,
    or when the backend value is not recognized. Previously a typo silently
    selected Redis.
    """
    # Tolerate a missing/None setting and stray whitespace (e.g. "mysql " from env files).
    backend = (settings.research_session_backend or "").strip().lower()

    if backend == "mysql":
        if mysql_client is not None:
            return MySQLResearchSessionStore(
                mysql_client,
                ttl_seconds=settings.research_session_ttl_seconds,
            )
        logger.warning("mysql_session_store_unavailable_fallback_to_redis")
    elif backend != "redis":
        logger.warning(
            f"research_session_backend_unknown_fallback_to_redis: {backend!r}"
        )

    return RedisResearchSessionStore(
        redis_client,
        ttl_seconds=settings.research_session_ttl_seconds,
    )