"""Redis-based caching for embedding vectors and search results.

Design
------
Embedding cache
    Key:  ``emb:{sha256(text)[:32]}``
    Value: JSON-serialised ``list[float]``
    TTL:  1 hour  (embeddings are deterministic and model-stable, so a hit
          saves a redundant call to the embedding API)

Search result cache
    Key:  ``search:{user_id}:{sha256(query+filters+page+page_size+llm)[:32]}``
    Value: JSON-serialised search result dict
    TTL:  2 minutes  (short enough that permission updates are visible quickly;
          permissions are append-only, so cached results are always a safe
          subset of what the user may see)

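Example
-------
Get-or-compute wiring on the caller side (a sketch; ``embed_text`` is a
hypothetical embedding client, not part of this module)::

    vector = await get_cached_embedding(redis, text)
    if vector is None:
        vector = await embed_text(text)  # hypothetical embedding API call
        await set_cached_embedding(redis, text, vector)

After a permission grant, call ``invalidate_search_cache(redis, user_id)`` so
the new grant becomes visible before the 2-minute TTL expires.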
"""

from __future__ import annotations

import hashlib
import json
from typing import Any

import redis.asyncio as aioredis

from app.utils.logger import get_logger

logger = get_logger(__name__)

_EMBEDDING_TTL = 3600      # 1 hour
_SEARCH_TTL = 120          # 2 minutes
_EMBEDDING_PREFIX = "emb:"
_SEARCH_PREFIX = "search:"


def _sha256(text: str) -> str:
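    # Truncated to 128 bits (32 hex chars): shorter keys with collision odds
    # that are negligible for cache fingerprints.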
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]


# ---------------------------------------------------------------------------
# Embedding cache
# ---------------------------------------------------------------------------


async def get_cached_embedding(redis: aioredis.Redis, text: str) -> list[float] | None:
    """Return a cached embedding vector, or ``None`` on cache miss."""
    try:
        key = f"{_EMBEDDING_PREFIX}{_sha256(text)}"
        raw = await redis.get(key)
        if raw:
            return json.loads(raw)
    except Exception as exc:
        logger.warning("embedding_cache_get_error", error=str(exc))
    return None


async def set_cached_embedding(
    redis: aioredis.Redis, text: str, vector: list[float]
) -> None:
    """Store an embedding vector in Redis."""
    try:
        key = f"{_EMBEDDING_PREFIX}{_sha256(text)}"
        await redis.setex(key, _EMBEDDING_TTL, json.dumps(vector))
    except Exception as exc:
        logger.warning("embedding_cache_set_error", error=str(exc))


# ---------------------------------------------------------------------------
# Search result cache
# ---------------------------------------------------------------------------


def _search_cache_key(
    user_id: str,
    query: str,
    filters: dict,
    page: int,
    page_size: int | None = None,
    llm: bool = False,
) -> str:
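    # Canonical fingerprint: sort_keys makes dict ordering irrelevant, so
    # logically identical requests always hash to the same cache key.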
    fingerprint = json.dumps(
        {"q": query, "f": filters, "p": page, "ps": page_size, "llm": llm},
        sort_keys=True,
        ensure_ascii=False,
    )
    return f"{_SEARCH_PREFIX}{user_id}:{_sha256(fingerprint)}"


async def get_cached_search(
    redis: aioredis.Redis,
    user_id: str,
    query: str,
    filters: dict,
    page: int,
    page_size: int | None = None,
    llm: bool = False,
) -> dict[str, Any] | None:
    """Return a cached search result dict, or ``None`` on miss."""
    try:
        key = _search_cache_key(user_id, query, filters, page, page_size, llm)
        raw = await redis.get(key)
        if raw:
            logger.debug("search_cache_hit", user_id=user_id, query=query[:40])
            return json.loads(raw)
    except Exception as exc:
        logger.warning("search_cache_get_error", error=str(exc))
    return None


async def set_cached_search(
    redis: aioredis.Redis,
    user_id: str,
    query: str,
    filters: dict,
    page: int,
    result: dict[str, Any],
    page_size: int | None = None,
    llm: bool = False,
) -> None:
    """Store a search result dict in Redis."""
    try:
        key = _search_cache_key(user_id, query, filters, page, page_size, llm)
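        # ensure_ascii=False keeps non-ASCII queries human-readable when
        # inspecting cached values in redis-cli.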
        await redis.setex(key, _SEARCH_TTL, json.dumps(result, ensure_ascii=False))
        logger.debug("search_cache_set", user_id=user_id, query=query[:40])
    except Exception as exc:
        logger.warning("search_cache_set_error", error=str(exc))


async def invalidate_search_cache(redis: aioredis.Redis, user_id: str) -> None:
    """Delete all cached search results for a user (call after permission update).

    权限变更后清除该用户的所有搜索缓存。
    """
    try:
        pattern = f"{_SEARCH_PREFIX}{user_id}:*"
        cursor = 0
        deleted = 0
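        # Cursor-based SCAN instead of KEYS: iterates incrementally and never
        # blocks Redis on a large keyspace.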
        while True:
            cursor, keys = await redis.scan(cursor, match=pattern, count=100)
            if keys:
                # UNLINK instead of DEL: memory is reclaimed asynchronously in
                # a background thread, so the Redis main thread is never
                # blocked, even for large values.
                await redis.unlink(*keys)
                deleted += len(keys)
            if cursor == 0:
                break
        if deleted:
            logger.info("search_cache_invalidated", user_id=user_id, keys=deleted)
    except Exception as exc:
        logger.warning("search_cache_invalidate_error", error=str(exc))
