"""LLM client using an OpenAI-compatible API (e.g. vLLM, Ollama, DashScope).

大语言模型客户端封装模块。
通过 OpenAI 兼容接口与 LLM 服务通信，支持流式对话、JSON 结构化输出
和单轮补全三种调用模式。可通过 extra_body 传递供应商专属参数
（如 DashScope 的 enable_thinking 开关）。
"""

from __future__ import annotations

import asyncio
import json
from typing import Any, AsyncIterator

import httpx
from openai import AsyncOpenAI

from app.config import settings
from app.utils.logger import get_logger

# Module-level logger. Call sites pass structured keyword fields
# (e.g. ``raw=...``, ``content_tail=...``), so this is presumably a
# structlog-style logger — confirm in app.utils.logger.
logger = get_logger(__name__)


class LLMClient:
    """Async wrapper around an OpenAI-compatible chat completions API.

    Supports ``extra_body`` for provider-specific parameters such as
    DashScope's ``enable_thinking`` flag (set ``LLM_ENABLE_THINKING=false``
    in environment to disable chain-of-thought for Qwen thinking models).

    异步 LLM 客户端，封装了流式对话、JSON 输出和单轮补全三种调用方式。
    通过 trust_env=False 绕过系统代理，确保直连 LLM 服务端点。
    """

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
    ) -> None:
        self._model = model or settings.llm_model
        self._temperature = temperature if temperature is not None else settings.llm_temperature
        self._max_tokens = max_tokens or settings.llm_max_tokens
        # Use trust_env=False to bypass system proxy settings so the client
        # connects directly to the LLM endpoint (e.g. DashScope).
        self._client = AsyncOpenAI(
            base_url=base_url or settings.llm_base_url,
            api_key=api_key or settings.llm_api_key,
            http_client=httpx.AsyncClient(trust_env=False),
        )
        # Build extra_body for provider-specific params
        self._extra_body: dict[str, Any] = {}
        if not settings.llm_enable_thinking:
            self._extra_body["enable_thinking"] = False

    async def close(self) -> None:
        """关闭 LLM 客户端及底层 httpx 连接，释放资源。
        Close the OpenAI client and its underlying httpx.AsyncClient."""
        # AsyncOpenAI.close() 会关闭内部 httpx 客户端，
        # 此处显式关闭自定义 httpx 客户端以确保连接池释放
        http_client = self._client._client  # the httpx.AsyncClient we injected
        await self._client.close()
        if http_client and not http_client.is_closed:
            await http_client.aclose()

    # ── Streaming chat ───────────────────────────────────────────────────

    async def chat(
        self,
        messages: list[dict[str, str]],
        *,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        extra_body: dict[str, Any] | None = None,
    ) -> AsyncIterator[str]:
        """Yield text chunks from a streaming chat completion.

        Parameters
        ----------
        messages:
            OpenAI-style message list, e.g.
            ``[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]``
        extra_body:
            Additional provider-specific body fields.  Merged with the
            instance-level ``_extra_body`` (per-call values take precedence).
        """
        merged_extra = {**self._extra_body, **(extra_body or {})}
        response = await self._client.chat.completions.create(
            model=model or self._model,
            messages=messages,
            temperature=temperature if temperature is not None else self._temperature,
            max_tokens=max_tokens or self._max_tokens,
            stream=True,
            extra_body=merged_extra if merged_extra else None,
        )

        # 流式读取设置 5 分钟超时，防止 LLM 长时间无响应导致连接挂起
        # Timeout streaming reads to avoid hanging on unresponsive LLM
        async with asyncio.timeout(300):
            async for chunk in response:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    yield delta.content

    # ── Non-streaming JSON output ────────────────────────────────────────

    async def chat_json(
        self,
        messages: list[dict[str, str]],
        *,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        extra_body: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Request a single chat completion with ``response_format=json_object``.

        Returns the parsed JSON dict from the assistant's reply.
        """
        import json as _json

        merged_extra = {**self._extra_body, **(extra_body or {})}
        response = await self._client.chat.completions.create(
            model=model or self._model,
            messages=messages,
            temperature=temperature if temperature is not None else self._temperature,
            max_tokens=max_tokens or self._max_tokens,
            response_format={"type": "json_object"},
            stream=False,
            extra_body=merged_extra if merged_extra else None,
        )

        content = response.choices[0].message.content or "{}"
        finish_reason = response.choices[0].finish_reason

        # Handle truncated JSON when LLM hits max_tokens limit
        if finish_reason == "length":
            logger.warning(
                "llm_response_truncated",
                max_tokens=max_tokens or self._max_tokens,
                content_tail=content[-200:],
            )
            # Attempt to salvage truncated JSON by closing open structures
            content = self._repair_truncated_json(content)

        try:
            return _json.loads(content)
        except _json.JSONDecodeError:
            logger.error("llm_json_parse_error", raw=content[:500])
            # JSON 解析失败时抛出异常，让调用方决定如何处理（而非静默返回错误 dict）
            # Raise instead of returning an error dict so callers can handle explicitly
            raise ValueError(
                f"Failed to parse LLM JSON response: {content[:200]}"
            )

    @staticmethod
    def _repair_truncated_json(text: str) -> str:
        """Best-effort repair of truncated JSON from LLM output.

        When the LLM hits max_tokens, the JSON is cut mid-stream.
        Strategy: strip the last incomplete value/key, then close all
        open brackets/braces so json.loads can parse what we have.
        """
        import json as _json

        # First try: maybe it's already valid
        try:
            _json.loads(text)
            return text
        except _json.JSONDecodeError:
            pass

        # Strip trailing incomplete token (partial string, number, etc.)
        # Find last complete structural char
        stripped = text.rstrip()
        # Remove trailing comma if present
        if stripped.endswith(","):
            stripped = stripped[:-1]

        # Progressively strip trailing chars until we find a structural char
        # that could be a valid JSON value ending
        while stripped and stripped[-1] not in ']}"0123456789truefalsn':
            stripped = stripped[:-1]

        # If we ended mid-string, try to close it
        # Count unmatched quotes
        in_string = False
        escaped = False
        for ch in stripped:
            if escaped:
                escaped = False
                continue
            if ch == '\\':
                escaped = True
                continue
            if ch == '"':
                in_string = not in_string
        if in_string:
            stripped += '"'

        # Now close open brackets/braces
        open_stack: list[str] = []
        in_str = False
        esc = False
        for ch in stripped:
            if esc:
                esc = False
                continue
            if ch == '\\' and in_str:
                esc = True
                continue
            if ch == '"':
                in_str = not in_str
                continue
            if in_str:
                continue
            if ch in '{[':
                open_stack.append(ch)
            elif ch == '}' and open_stack and open_stack[-1] == '{':
                open_stack.pop()
            elif ch == ']' and open_stack and open_stack[-1] == '[':
                open_stack.pop()

        # Close in reverse order
        for bracket in reversed(open_stack):
            stripped += ']' if bracket == '[' else '}'

        # Remove trailing comma before closing bracket (invalid JSON)
        import re
        stripped = re.sub(r',\s*([}\]])', r'\1', stripped)

        try:
            _json.loads(stripped)
            logger.info("llm_truncated_json_repaired", salvaged_len=len(stripped))
            return stripped
        except _json.JSONDecodeError:
            logger.warning("llm_truncated_json_repair_failed")
            # Return original — caller will handle the parse error
            return text

    # ── Simple single-shot helper ────────────────────────────────────────

    async def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        extra_body: dict[str, Any] | None = None,
    ) -> str:
        """Non-streaming single-shot completion. Returns the full text."""
        messages: list[dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        merged_extra = {**self._extra_body, **(extra_body or {})}
        response = await self._client.chat.completions.create(
            model=model or self._model,
            messages=messages,
            temperature=temperature if temperature is not None else self._temperature,
            max_tokens=max_tokens or self._max_tokens,
            stream=False,
            extra_body=merged_extra if merged_extra else None,
        )
        return response.choices[0].message.content or ""
