"""
向量化服务 —— 批量生成文本嵌入向量，支持并发控制、自动重试和速率限制。
Batch embedding service with retry logic and rate limiting.

搜索查询的向量化结果会缓存到 Redis（TTL 1 小时），避免重复调用外部 API。
"""

from __future__ import annotations

import asyncio
import random

import redis.asyncio as aioredis

from app.infrastructure.embedding_client import EmbeddingClient
from app.utils.logger import get_logger
from app.utils.query_cache import get_cached_embedding, set_cached_embedding

logger = get_logger(__name__)

# Maximum texts per single API call.
# DashScope text-embedding-v3 limits batches to ≤10 texts.
# Reduce this if your API provider has tighter limits.
DEFAULT_BATCH_SIZE = 6
MAX_RETRIES = 3
RETRY_DELAY = 2.0  # seconds


class EmbeddingService:
    """向量化服务，将文本列表分批发送到 Embedding API 生成向量，支持并发控制和失败重试。

    Generate embeddings in batches with error handling and retry logic.
    单条查询（embed_single）支持 Redis 缓存，避免重复调用外部 API。
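
    Example (illustrative sketch; assumes a configured EmbeddingClient and,
    optionally, a redis.asyncio connection)::

        service = EmbeddingService(client, redis=redis_conn, batch_size=8)
        vectors = await service.embed_chunks(["first chunk", "second chunk"])
        query_vec = await service.embed_single("user query")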
    """

    def __init__(
        self,
        client: EmbeddingClient,
        *,
        redis: aioredis.Redis | None = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        max_retries: int = MAX_RETRIES,
        retry_delay: float = RETRY_DELAY,
        max_concurrency: int = 4,
    ):
        self._client = client
        self._redis = redis
        self._batch_size = batch_size
        self._max_retries = max_retries
        self._retry_delay = retry_delay
        self._semaphore = asyncio.Semaphore(max_concurrency)

    async def embed_chunks(
        self, texts: list[str]
    ) -> list[list[float]]:
        """批量生成嵌入向量：按 batch_size 分批、控制并发数、失败自动重试。

        Generate embeddings for a list of chunk texts.

        Splits into batches, processes with controlled concurrency,
        and retries on failure.

        Args:
            texts: List of text strings to embed.

        Returns:
            List of embedding vectors (same order as input).

        Raises:
            RuntimeError: If any batch fails after all retries.
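
        Example (illustrative)::

            vectors = await service.embed_chunks(["chunk a", "chunk b"])
            assert len(vectors) == 2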
        """
        if not texts:
            return []

        # Split into batches
        batches = [
            texts[i : i + self._batch_size]
            for i in range(0, len(texts), self._batch_size)
        ]

        logger.info(
            "embedding_start",
            total_texts=len(texts),
            batches=len(batches),
            batch_size=self._batch_size,
        )

        # Process batches with controlled concurrency
        tasks = [
            self._process_batch(batch, batch_idx)
            for batch_idx, batch in enumerate(batches)
        ]
        # return_exceptions=True makes gather collect failures instead of
        # raising on the first one, so every batch result is inspected below
        batch_results = await asyncio.gather(*tasks, return_exceptions=True)

        # Check each result; log and re-raise the first exception encountered
        all_vectors: list[list[float]] = []
        for idx, result in enumerate(batch_results):
            if isinstance(result, BaseException):
                logger.error(
                    "batch_embedding_failed",
                    batch=idx,
                    error=str(result),
                )
                raise RuntimeError(
                    f"Batch {idx} embedding failed: {result}"
                ) from result
            all_vectors.extend(result)

        if len(all_vectors) != len(texts):
            raise RuntimeError(
                f"Embedding count mismatch: expected {len(texts)}, "
                f"got {len(all_vectors)}"
            )

        logger.info("embedding_complete", total_vectors=len(all_vectors))
        return all_vectors

    async def _process_batch(
        self, texts: list[str], batch_idx: int
    ) -> list[list[float]]:
        """Process a single batch with retry logic."""
        async with self._semaphore:
            last_error: Exception | None = None
            for attempt in range(1, self._max_retries + 1):
                try:
                    vectors = await self._client.embed_texts(texts)
                    logger.debug(
                        "batch_embedded",
                        batch=batch_idx,
                        texts=len(texts),
                        attempt=attempt,
                    )
                    return vectors
                except Exception as e:
                    last_error = e
                    logger.warning(
                        "batch_embed_retry",
                        batch=batch_idx,
                        attempt=attempt,
                        error=str(e),
                    )
                    if attempt < self._max_retries:
                        # Exponential backoff with random jitter so concurrent
                        # batches don't all retry at once (thundering herd)
                        backoff = self._retry_delay * (2 ** (attempt - 1))
                        jitter = random.uniform(0, backoff * 0.1)
                        await asyncio.sleep(backoff + jitter)

            raise RuntimeError(
                f"Batch {batch_idx} failed after {self._max_retries} attempts: "
                f"{last_error}"
            )

    async def embed_single(self, text: str) -> list[float]:
        """Embed a single text (convenience for query embedding).

        Checks the Redis cache first (1 hour TTL); on a miss, calls the API
        and stores the result.
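
        Example (inside an async context)::

            vector = await service.embed_single("what is vector search?")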
        """
        # 1. Try cache
        if self._redis is not None:
            cached = await get_cached_embedding(self._redis, text)
            if cached is not None:
                logger.debug("embedding_cache_hit", text=text[:40])
                return cached

        # 2. Call API
        vector = await self._client.embed_query(text)

        # 3. Store in cache; best-effort, so a cache write failure
        # does not fail the request itself
        if self._redis is not None:
            try:
                await set_cached_embedding(self._redis, text, vector)
            except Exception as e:
                logger.warning("embedding_cache_write_failed", error=str(e))

        return vector
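

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only; runnable where the app.*
# imports above resolve). It exercises EmbeddingService against a stub client,
# so no real EmbeddingClient or Redis wiring is needed. The stub's
# embed_texts/embed_query signatures mirror how this module calls the client;
# the real interface lives in app.infrastructure.embedding_client.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubClient:
        """Fake client that returns fixed-size zero vectors."""

        async def embed_texts(self, texts: list[str]) -> list[list[float]]:
            return [[0.0] * 8 for _ in texts]

        async def embed_query(self, text: str) -> list[float]:
            return [0.0] * 8

    async def _demo() -> None:
        # redis=None disables caching; batch_size=2 forces multiple batches
        service = EmbeddingService(_StubClient(), batch_size=2)  # type: ignore[arg-type]
        vectors = await service.embed_chunks([f"chunk {i}" for i in range(5)])
        print(f"embedded {len(vectors)} chunks, dim={len(vectors[0])}")
        query_vec = await service.embed_single("hello")
        print(f"query dim={len(query_vec)}")

    asyncio.run(_demo())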
