from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.config import settings
from app.infrastructure.session_store import (
    MySQLResearchSessionStore,
    RedisResearchSessionStore,
    build_research_session_store,
)


class FakeMySQLClient:
    """In-memory stand-in for the MySQL client consumed by the session store.

    Rows are kept in ``self.rows`` keyed by the first two query parameters.
    The SQL text passed to each method is ignored entirely.
    """

    def __init__(self) -> None:
        # (session_id, user_id) -> stored history payload
        self.rows: dict[tuple[str, str], str] = {}

    async def fetch_one(self, query: str, params: tuple[str, str]):
        """Return ``{"history": <payload>}`` for the keyed row, else ``None``."""
        del query  # the fake does not interpret SQL
        stored = self.rows.get((params[0], params[1]))
        return None if stored is None else {"history": stored}

    async def execute(self, query: str, params):
        """Upsert the payload for ``(session_id, user_id)``; the expiry value is discarded.

        Always returns ``1`` (presumably standing in for an affected-row count).
        """
        del query  # the fake does not interpret SQL
        key_session, key_user, serialized, _ignored_expiry = params
        self.rows[key_session, key_user] = serialized
        affected = 1
        return affected


@pytest.mark.asyncio
async def test_mysql_session_store_isolated_by_user() -> None:
    """Two users sharing a session id must get separate MySQL-backed histories."""
    mysql_client = FakeMySQLClient()
    store = MySQLResearchSessionStore(mysql_client, ttl_seconds=300)

    await store.save_session("session-1", "user-a", "q1", "a1", [])
    await store.save_session("session-1", "user-b", "q2", "a2", [])

    # The fake keys rows by (session_id, user_id); if the second save overwrote
    # the first, only one row would remain.
    assert len(mysql_client.rows) == 2

    user_a_history = await store.load_session("session-1", "user-a")
    user_b_history = await store.load_session("session-1", "user-b")

    assert len(user_a_history) == 1
    assert len(user_b_history) == 1
    assert user_a_history[0]["question"] == "q1"
    assert user_a_history[0]["answer"] == "a1"
    assert user_b_history[0]["question"] == "q2"
    assert user_b_history[0]["answer"] == "a2"


@pytest.mark.asyncio
async def test_redis_session_store_isolated_by_user() -> None:
    """Two users sharing a session id must get separate Redis-backed histories."""
    redis_client = MagicMock()
    redis_client.raw = MagicMock()

    # Minimal in-memory backing for the ``raw.get`` / ``raw.setex`` calls.
    payloads: dict[str, str] = {}

    async def fake_get(key: str):
        return payloads.get(key)

    async def fake_setex(key: str, _ttl: int, value: str):
        payloads[key] = value

    redis_client.raw.get = AsyncMock(side_effect=fake_get)
    redis_client.raw.setex = AsyncMock(side_effect=fake_setex)

    store = RedisResearchSessionStore(redis_client, ttl_seconds=300)
    await store.save_session("session-1", "user-a", "q1", "a1", [])
    await store.save_session("session-1", "user-b", "q2", "a2", [])

    user_a_history = await store.load_session("session-1", "user-a")
    user_b_history = await store.load_session("session-1", "user-b")

    assert len(user_a_history) == 1
    assert len(user_b_history) == 1
    assert user_a_history[0]["question"] == "q1"
    assert user_a_history[0]["answer"] == "a1"
    assert user_b_history[0]["question"] == "q2"
    assert user_b_history[0]["answer"] == "a2"
    assert redis_client.raw.setex.await_count == 2
    # Each user must be written under its own key; a shared key would leave
    # a single entry behind after the two writes.
    assert len(payloads) == 2


def test_build_research_session_store_falls_back_to_redis(monkeypatch: pytest.MonkeyPatch) -> None:
    """Selecting the MySQL backend with no MySQL client must yield the Redis store."""
    monkeypatch.setattr(settings, "research_session_backend", "mysql")
    fake_redis = MagicMock()

    selected = build_research_session_store(redis_client=fake_redis, mysql_client=None)

    assert isinstance(selected, RedisResearchSessionStore)