"""研究引擎检索层。

封装关键词提取、ES 混合检索、图谱文档补充等检索相关逻辑，
由 ResearchEngine 通过组合方式使用。
"""

from __future__ import annotations

from typing import Any

from app.config import settings
from app.core.embedding import EmbeddingService
from app.core.graph_query_planner import GraphQueryPlanner
from app.core.graph_query_service import GraphQueryService
from app.core.permission import PermissionContext
from app.core.research_formatter import (
    _aggregate_es_hits,
)
from app.infrastructure.es_client import ESClient, HYBRID_RRF_PIPELINE
from app.infrastructure.llm_client import LLMClient
from app.prompts.research_prompts import (
    KEYWORD_EXTRACTION_SYSTEM,
    KEYWORD_EXTRACTION_USER,
)
from app.utils.logger import get_logger

logger = get_logger(__name__)

# --- Context-size limits used when building retrieval results -------------
_MAX_CONTEXT_DOCS: int = 12  # documents kept for the LLM context
_MAX_PASSAGES_PER_DOC: int = 3  # highlight fragments requested per document
_MAX_PASSAGE_CHARS: int = 600  # highlight fragment_size, in characters


class ResearchRetriever:
    """Research retriever: keyword extraction plus multi-source document retrieval.

    Parameters
    ----------
    es_client:
        Async Elasticsearch client used for hybrid / BM25 retrieval and for
        meta-index lookups.
    embedding_service:
        Embedding service used to build the kNN query vector.
    graph_service:
        Neo4j graph query service used for graph-based document discovery.
    llm_client:
        LLM client used for keyword extraction.
    planner:
        Graph query planner used for intent recognition and evidence collection.
    """

    # Fields requested from the meta index and backfilled onto docs lacking them.
    _META_FIELDS: tuple[str, ...] = (
        "title",
        "doc_number",
        "issuing_org",
        "doc_type",
        "publish_date",
        "source_url",
        "source_metadata",
        "source_system",
        "source_site_code",
        "source_target_code",
    )
    # Default Elasticsearch ``index.max_result_window``; a larger ``size`` is
    # rejected unless the index overrides that setting.
    _MAX_RESULT_WINDOW = 10_000

    def __init__(
        self,
        es_client: ESClient,
        embedding_service: EmbeddingService,
        graph_service: GraphQueryService,
        llm_client: LLMClient,
        planner: GraphQueryPlanner,
    ) -> None:
        self._es = es_client
        self._embedding = embedding_service
        self._graph = graph_service
        self._llm = llm_client
        self._planner = planner

    # ==================================================================
    # Shared helpers
    # ==================================================================

    @staticmethod
    def _unwrap_response(response: Any) -> dict[str, Any]:
        """Return the raw dict body of an ES response.

        The client may return either a plain ``dict`` or a response object
        exposing ``.body``; both shapes are accepted.
        """
        return response if isinstance(response, dict) else response.body

    @staticmethod
    def _highlight_config() -> dict[str, Any]:
        """Build a fresh ``highlight`` section bounding passages per document."""
        return {
            "fields": {
                "content": {
                    "fragment_size": _MAX_PASSAGE_CHARS,
                    "number_of_fragments": _MAX_PASSAGES_PER_DOC,
                }
            }
        }

    async def _search_chunks(self, body: dict[str, Any]) -> dict[str, Any]:
        """Run a plain search against the chunk index and return the raw body."""
        response = await self._es.raw.search(
            index=settings.es_chunk_index, body=body,
        )
        return self._unwrap_response(response)

    # ==================================================================
    # Keyword extraction
    # ==================================================================

    @staticmethod
    def _merge_keywords(result: Any, limit: int = 8) -> list[str]:
        """Merge the LLM's ``keywords`` / ``organizations`` / ``regions`` lists.

        Order is keywords, then organizations, then regions. Non-string and
        blank entries are dropped, terms are stripped and de-duplicated
        (first occurrence wins), and at most ``limit`` terms are returned.
        A missing/``null`` key counts as empty; a bare string counts as a
        one-element list. A non-dict ``result`` yields ``[]``.
        """
        if not isinstance(result, dict):
            return []
        merged: list[str] = []
        seen: set[str] = set()
        for key in ("keywords", "organizations", "regions"):
            values = result.get(key) or []
            if isinstance(values, str):
                values = [values]
            if not isinstance(values, (list, tuple)):
                continue
            for value in values:
                if not isinstance(value, str):
                    continue
                term = value.strip()
                if term and term not in seen:
                    seen.add(term)
                    merged.append(term)
        return merged[:limit]

    async def extract_keywords(self, question: str) -> list[str]:
        """Extract search keywords (incl. organization and region names) via the LLM.

        The terms feed both ES and graph retrieval. The LLM call is
        deterministic and short (``temperature=0.0``, ``max_tokens=512``).
        Any failure of the call is logged and yields ``[]`` so retrieval can
        proceed without keywords.

        Returns
        -------
        list[str]
            Up to 8 de-duplicated terms: keywords, then organizations, then
            regions.
        """
        try:
            result = await self._llm.chat_json(
                [
                    {"role": "system", "content": KEYWORD_EXTRACTION_SYSTEM},
                    {
                        "role": "user",
                        "content": KEYWORD_EXTRACTION_USER.format(question=question),
                    },
                ],
                temperature=0.0,
                max_tokens=512,
            )
        except Exception as exc:
            logger.warning("keyword_extraction_failed", error=str(exc))
            return []
        return self._merge_keywords(result)

    # ==================================================================
    # ES retrieval
    # ==================================================================

    async def es_search(
        self, question: str, perm: PermissionContext
    ) -> list[dict[str, Any]]:
        """Run ES hybrid retrieval (BM25 + kNN, RRF-fused); return doc-level results.

        ``ESClient.should_use_hybrid`` decides whether the hybrid path is
        attempted. A BM25-only query is used instead when hybrid is disabled,
        the embedding fails or yields no vector, ``ESClient.hybrid_search``
        reports ``ok=False``, or it raises. Chunk hits are aggregated per
        ``doc_id`` (at most ``_MAX_CONTEXT_DOCS`` docs) and backfilled from
        the meta index. Any other failure is logged and ``[]`` is returned.
        """
        try:
            acl_filter = perm.build_es_filter()
            combined_filter: dict[str, Any] = {"bool": {"must": [acl_filter]}}
            # Fetch more chunks than docs: hits are aggregated per doc_id.
            fetch_size = _MAX_CONTEXT_DOCS * 5

            bm25_query: dict[str, Any] = {
                "bool": {
                    "must": [
                        {
                            "multi_match": {
                                "query": question,
                                "fields": ["title^3", "content"],
                                "type": "best_fields",
                            }
                        }
                    ],
                    "filter": [combined_filter],
                }
            }

            raw: dict[str, Any] | None = None
            if self._es.should_use_hybrid:
                raw = await self._hybrid_search_or_none(
                    question, bm25_query, combined_filter, fetch_size
                )
            if raw is None:
                # Hybrid disabled or unavailable for this request: BM25 only.
                raw = await self._search_chunks(
                    {
                        "size": fetch_size,
                        "query": bm25_query,
                        "_source": {"excludes": ["content_vector"]},
                        "highlight": self._highlight_config(),
                    }
                )

            docs = _aggregate_es_hits(raw, _MAX_CONTEXT_DOCS)
            await self._enrich_from_meta(docs)
            return docs
        except Exception as exc:
            logger.error("research_es_search_failed", error=str(exc))
            return []

    async def _hybrid_search_or_none(
        self,
        question: str,
        bm25_query: dict[str, Any],
        combined_filter: dict[str, Any],
        fetch_size: int,
    ) -> dict[str, Any] | None:
        """Try the hybrid (BM25 + kNN) query.

        Returns the raw response body, or ``None`` to tell the caller to fall
        back to BM25-only (embedding error / no vector / hybrid not ok).
        """
        try:
            query_vector = await self._embedding.embed_single(question)
        except Exception as emb_err:
            logger.warning(
                "research_embedding_failed_bm25_fallback",
                error=str(emb_err),
                question=question[:40],
            )
            return None
        if query_vector is None:
            return None

        hybrid_body: dict[str, Any] = {
            "size": fetch_size,
            "query": {
                "hybrid": {
                    "queries": [
                        bm25_query,
                        {
                            "knn": {
                                "content_vector": {
                                    "vector": query_vector,
                                    "k": fetch_size,
                                    "filter": combined_filter,
                                },
                            },
                        },
                    ],
                },
            },
            "_source": {"excludes": ["content_vector"]},
            "highlight": self._highlight_config(),
        }
        try:
            response, ok = await self._es.hybrid_search(
                hybrid_body,
                index=settings.es_chunk_index,
                pipeline=HYBRID_RRF_PIPELINE,
            )
        except Exception as hyb_err:
            # NOTE(review): hybrid_search normally reports failure via ``ok``;
            # a raise is treated the same way so the BM25 fallback still runs.
            logger.warning("research_hybrid_search_error", error=str(hyb_err))
            response, ok = None, False
        if not ok:
            logger.info("research_hybrid_to_bm25_fallback")
            return None
        return self._unwrap_response(response)

    # ==================================================================
    # Graph-enriched doc fetch
    # ==================================================================

    async def fetch_graph_docs(
        self,
        doc_ids: list[str],
        question: str | None,
        perm: PermissionContext,
    ) -> list[dict[str, Any]]:
        """Fetch highlighted chunks for documents discovered via graph traversal.

        Results are ACL-filtered and aggregated per document. When
        ``question`` is given, it boosts (``should``) the chunks used for
        highlights. Ids that the main query does not surface are retried one
        by one (best effort; a failing retry is logged and skipped). Every
        returned doc is tagged ``_source_type = "graph"``. Failure of the
        main query is logged and yields ``[]``.
        """
        if not doc_ids:
            return []
        # Duplicate ids would inflate the terms query, size and aggregation
        # limit without adding documents.
        unique_ids = list(dict.fromkeys(doc_ids))
        try:
            acl_filter = perm.build_es_filter()
            # NOTE(review): chunks are matched on ``doc_ids`` here while hits
            # are keyed by ``doc_id``; confirm against the chunk index mapping.
            query: dict[str, Any] = {
                "bool": {
                    "must": [{"terms": {"doc_ids": unique_ids}}],
                    "filter": [acl_filter],
                }
            }
            if question:
                query["bool"]["should"] = [
                    {
                        "multi_match": {
                            "query": question,
                            "fields": ["title^2", "content"],
                            "type": "best_fields",
                        }
                    }
                ]
            size = min(max(len(unique_ids) * 10, 50), self._MAX_RESULT_WINDOW)
            raw = await self._search_chunks(
                {
                    "size": size,
                    "query": query,
                    "_source": {"excludes": ["content_vector"]},
                    "highlight": self._highlight_config(),
                }
            )
            docs = _aggregate_es_hits(
                raw, len(unique_ids), preferred_doc_ids=unique_ids
            )

            # Low-score docs can be pushed out of the top hits by high-score
            # ones; fetch any requested id that is still missing.
            found_ids = {d.get("doc_id") for d in docs}
            for missing_id in unique_ids:
                if missing_id in found_ids:
                    continue
                for fb_doc in await self._fetch_single_graph_doc(
                    missing_id, acl_filter
                ):
                    fb_id = fb_doc.get("doc_id")
                    if fb_id not in found_ids:
                        found_ids.add(fb_id)
                        docs.append(fb_doc)

            await self._enrich_from_meta(docs)
            for doc in docs:
                doc["_source_type"] = "graph"
            return docs
        except Exception as exc:
            logger.error("graph_doc_fetch_failed", error=str(exc))
            return []

    async def _fetch_single_graph_doc(
        self, doc_id: str, acl_filter: Any
    ) -> list[dict[str, Any]]:
        """Best-effort fetch of one document's top chunks; errors yield ``[]``."""
        body: dict[str, Any] = {
            "size": 3,
            "query": {
                "bool": {
                    "must": [{"term": {"doc_ids": doc_id}}],
                    "filter": [acl_filter],
                }
            },
            "_source": {"excludes": ["content_vector"]},
            "highlight": self._highlight_config(),
        }
        try:
            raw = await self._search_chunks(body)
            return _aggregate_es_hits(raw, 1, preferred_doc_ids=[doc_id])
        except Exception as exc:
            # One failing retry must not discard the docs already found.
            logger.warning(
                "graph_doc_fallback_fetch_failed", doc_id=doc_id, error=str(exc)
            )
            return []

    async def _enrich_from_meta(self, docs: list[dict[str, Any]]) -> None:
        """Backfill missing metadata on ``docs`` from the meta index, in place.

        Only docs that have a ``doc_id`` and lack one of ``title`` /
        ``source_url`` / ``source_metadata`` / ``source_system`` trigger a
        lookup; their ids are fetched with a single ``mget``. Any of
        ``_META_FIELDS`` a doc already has (truthy) is never overwritten.
        A failed lookup is logged and leaves ``docs`` unchanged.
        """

        def _needs_meta_enrichment(doc: dict[str, Any]) -> bool:
            return bool(
                doc.get("doc_id")
                and (
                    not doc.get("title")
                    or not doc.get("source_url")
                    or not doc.get("source_metadata")
                    or not doc.get("source_system")
                )
            )

        pending = [d["doc_id"] for d in docs if _needs_meta_enrichment(d)]
        if not pending:
            return
        try:
            resp = await self._es.raw.mget(
                index=settings.es_meta_index,
                body={"ids": list(dict.fromkeys(pending))},
                _source=list(self._META_FIELDS),
            )
            raw = self._unwrap_response(resp)
            meta_map: dict[str, dict[str, Any]] = {
                item["_id"]: item.get("_source") or {}
                for item in raw.get("docs", [])
                if item.get("found")
            }
        except Exception as exc:
            logger.warning("enrich_from_meta_failed", error=str(exc))
            return

        for doc in docs:
            # ``doc_id`` may be absent (see _needs_meta_enrichment), so use .get.
            meta = meta_map.get(doc.get("doc_id"))
            if not meta:
                continue
            for field in self._META_FIELDS:
                if doc.get(field):
                    continue
                if field == "title":
                    doc[field] = meta.get(field, "")
                elif field == "source_metadata":
                    doc[field] = meta.get(field) or {}
                else:
                    doc[field] = meta.get(field)

