"""Embedding client using an OpenAI-compatible embeddings API.

文本向量化客户端封装模块。
通过 OpenAI 兼容的 Embeddings API 将文本转换为稠密向量，
支持批量编码和单条查询编码，默认使用 bge-m3 模型（1024 维）。
"""

from __future__ import annotations

import httpx
from openai import AsyncOpenAI

from app.config import settings
from app.utils.logger import get_logger

logger = get_logger(__name__)


class EmbeddingClient:
    """Async wrapper for computing text embeddings via an OpenAI-compatible API.

    Supports batch embedding of many texts in one round trip as well as a
    convenience path for embedding a single query string.
    """

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        model: str | None = None,
        dimensions: int | None = None,
    ) -> None:
        """Create a client.

        Every argument falls back to the corresponding ``settings`` value
        when omitted (``embedding_base_url``, ``embedding_api_key``,
        ``embedding_model``, ``embedding_dimensions``).
        """
        self._model = model or settings.embedding_model
        self._dimensions = dimensions or settings.embedding_dimensions
        # A 10 s total / 5 s connect timeout keeps an unresponsive external
        # API from blocking search requests for long. Built once so the
        # client-level and transport-level timeouts cannot drift apart.
        timeout = httpx.Timeout(10.0, connect=5.0)
        # trust_env=False bypasses system proxy settings so the client
        # connects directly to the embedding endpoint (e.g. DashScope).
        self._client = AsyncOpenAI(
            base_url=base_url or settings.embedding_base_url,
            api_key=api_key or settings.embedding_api_key,
            timeout=timeout,
            http_client=httpx.AsyncClient(trust_env=False, timeout=timeout),
        )

    async def close(self) -> None:
        """Release the underlying HTTP resources."""
        await self._client.close()

    async def embed_texts(
        self,
        texts: list[str],
        *,
        model: str | None = None,
    ) -> list[list[float]]:
        """Embed a batch of texts and return one vector per input text.

        Each vector has *self._dimensions* dimensions (default 1024 for
        ``bge-m3``). The OpenAI embeddings API accepts a list of strings,
        so a single round trip is made for the whole batch.

        Args:
            texts: Texts to embed; an empty list short-circuits to ``[]``.
            model: Optional per-call model override.

        Raises:
            Exception: Any error from the embeddings API is logged and
                re-raised so the caller can decide how to retry.
        """
        if not texts:
            return []

        # Wrap the API call with error handling; log and re-raise so the
        # caller can apply its own retry policy.
        try:
            response = await self._client.embeddings.create(
                model=model or self._model,
                input=texts,
                dimensions=self._dimensions,
            )
        except Exception as exc:
            logger.error(
                "embedding_api_error",
                model=model or self._model,
                text_count=len(texts),
                error=str(exc),
            )
            raise

        # Each item carries an explicit ``index`` naming its position in the
        # input batch. Sort by it defensively so vectors line up with the
        # input texts even if the server returns items out of order.
        vectors = [
            item.embedding
            for item in sorted(response.data, key=lambda item: item.index)
        ]

        # Auto-correct dimension mismatches: trim extra dimensions or
        # zero-pad up to the configured target size, logging each fix.
        for i, vec in enumerate(vectors):
            if len(vec) != self._dimensions:
                logger.warning(
                    "embedding_dim_mismatch",
                    expected=self._dimensions,
                    actual=len(vec),
                    index=i,
                )
                if len(vec) > self._dimensions:
                    # Trim: keep the first N dimensions.
                    vectors[i] = vec[: self._dimensions]
                else:
                    # Pad: append 0.0 until the target dimension is reached.
                    vectors[i] = vec + [0.0] * (self._dimensions - len(vec))

        return vectors

    async def embed_query(
        self,
        text: str,
        *,
        model: str | None = None,
    ) -> list[float]:
        """Embed a single query text and return its vector.

        Convenience wrapper around :meth:`embed_texts` for the common
        single-query use case.
        """
        vectors = await self.embed_texts([text], model=model)
        return vectors[0]
