"""Knowledge graph builder — LLM entity extraction → normalisation → Neo4j write.

Pipeline per document:
    1. Call LLM with structured entity-extraction prompt → JSON
    2. Normalise / deduplicate / validate entities (scene-aware)
    3. Regex-scan content for additional document-number references
    4. MERGE the subgraph into Neo4j (idempotent)

知识图谱构建器模块。
负责单个文档的图谱构建全流程：
1. 调用 LLM 进行结构化实体抽取（返回 JSON）
2. 实体名称规范化、去重、类型校验（支持场景化抽取集合）
3. 正则扫描正文发现额外的公文文号引用
4. 幂等写入 Neo4j（MERGE 语义）
抽取结果包括机构、事项、条件、材料、时限、政策主题等实体类型，
以及文档间的引用/依据/修订/废止等关系。
"""

from __future__ import annotations

import hashlib
from typing import Any

from app.core.graph_schema_loader import get_schema
from app.infrastructure.llm_client import LLMClient
from app.infrastructure.neo4j_client import Neo4jClient
from app.prompts.entity_extraction import (
    ENTITY_EXTRACTION_SYSTEM,
    ENTITY_EXTRACTION_USER,
    build_scene_system_prompt,
    build_system_prompt,
)
from app.utils.doc_code import iter_doc_codes
from app.utils.logger import get_logger
from app.utils.text_normalize import normalize_title

logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Entity labels that receive an auto-generated internal ID, mapped to the
# property name that stores that ID.
# The ID is deterministic (label prefix + hash of the entity name, see
# _generate_entity_id()) and is only filled in when the extracted properties
# do not already provide it.
# It does NOT affect the Neo4j merge key (which remains ``name``) for most
# types.
# Person is deliberately absent: it is handled separately via
# _generate_person_id().
_AUTO_ID_FIELDS: dict[str, str] = {
    # Phase 1 (existing)
    "Matter": "matter_id",
    "Condition": "condition_id",
    "Material": "material_id",
    "TimeLimit": "time_limit_id",
    "TargetGroup": "target_group_id",
    # Phase 0 (new core entities)
    "Policy": "policy_id",
    "Task": "task_id",
    "Project": "project_id",
    "System": "system_id",
    "DataResource": "data_resource_id",
    "Indicator": "indicator_id",
    "Budget": "budget_id",
    "Industry": "industry_id",
    # Phase 3 (office extension)
    "Event": "event_id",
    "Mechanism": "mechanism_id",
    "Standard": "standard_id",
    "Infrastructure": "infrastructure_id",
    "Technology": "technology_id",
}


def _generate_entity_id(label: str, normalized_name: str) -> str:
    """Build a deterministic internal ID for an entity.

    The result is ``{label_lower}_{sha256_prefix}`` where the prefix is the
    first 12 hex characters of the SHA-256 digest of the UTF-8 encoded
    *normalized_name* (e.g. ``matter_a1b2c3d4e5f6``).
    """
    # SHA-256 is used here (rather than MD5) to reduce the collision risk.
    digest = hashlib.sha256(normalized_name.encode("utf-8"))
    return "_".join((label.lower(), digest.hexdigest()[:12]))


def _generate_person_id(
    name: str, props: dict[str, Any], doc_id: str
) -> str:
    """Generate a deterministic person_id using name + context attributes.

    The hash seed is ``name|serving_org|position|doc_id[:8]``; the result is
    ``person_`` followed by the first 12 hex characters of its SHA-256 digest.

    Degradation rules:
    - Missing, ``None`` or empty serving_org/position → empty string in the
      seed (conservative).  An explicit ``None`` (JSON ``null``) is treated
      the same as an absent key so both forms yield the same ID.
    - Both missing → hash(name + doc_id prefix) → same-doc same-name merges,
      cross-doc does not merge as long as the doc_id prefixes differ.

    Args:
        name:   Person name as written to the graph.
        props:  Entity properties; only ``serving_org`` and ``position`` are read.
        doc_id: Document identifier; only its first 8 characters are used.

    Returns:
        The generated ``person_<12 hex chars>`` identifier.
    """
    # ``or ""`` (instead of a .get() default) also covers keys that are
    # present but null/empty, which would otherwise be hashed as "None".
    org = props.get("serving_org") or ""
    position = props.get("position") or ""
    seed = f"{name}|{org}|{position}|{doc_id[:8]}"
    h = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:12]
    return f"person_{h}"


def _get_valid_entity_types_for_scene(scene_type: str = "") -> set[str]:
    """Return the entity type names accepted for *scene_type*.

    An empty *scene_type* yields the schema's default ``entity_type_names()``
    set; otherwise the scene-specific set is returned.
    """
    schema = get_schema()
    if not scene_type:
        return schema.entity_type_names()
    return schema.entity_type_names_for_scene(scene_type)


def _get_all_valid_entity_types() -> set[str]:
    """Return every entity type name known to the schema (unfiltered, all phases)."""
    schema = get_schema()
    return schema.all_entity_type_names_unfiltered()


def _get_all_valid_rel_types() -> set[str]:
    """Return every relationship type name known to the schema (unfiltered, all phases)."""
    schema = get_schema()
    return schema.all_rel_type_names_unfiltered()


def _resolve_entity_key(label: str, name: str, props: dict[str, Any]) -> str:
    """Return the merge-key value of an entity, based on its label's key property.

    - ``Document`` is keyed by ``doc_id`` (the schema is not consulted).
    - Any other label is keyed by the ``key_property`` declared for it in the
      unfiltered schema type map, defaulting to ``name`` when the label or the
      declaration is absent.

    In every case *name* is the fallback when *props* lacks the key property.

    This is the single key resolver: relationship assembly uses it to compute
    from_key / to_key so that entities whose primary key is not ``name``
    (e.g. Person with ``person_id``) are connected through the right value.
    """
    if label != "Document":
        type_defs = get_schema().all_entity_type_map_unfiltered()
        key_property = type_defs.get(label, {}).get("key_property", "name")
    else:
        key_property = "doc_id"
    return props.get(key_property, name)


class GraphBuilder:
    """Extracts entities / relationships from a government document and
    writes the resulting subgraph into Neo4j.

    Parameters
    ----------
    llm_client:
        OpenAI-compatible async LLM client for entity extraction.
    neo4j_client:
        Async Neo4j driver wrapper.
    max_content_chars:
        Maximum characters of body text sent to the LLM.
        Longer texts are truncated with a marker.
    """

    def __init__(
        self,
        llm_client: LLMClient,
        neo4j_client: Neo4jClient,
        *,
        max_content_chars: int = 12_000,
    ) -> None:
        self._llm = llm_client
        self._neo4j = neo4j_client
        self._max_chars = max_content_chars

    # ==================================================================
    # Public API
    # ==================================================================

    async def build_graph(
        self,
        doc_id: str,
        metadata: dict[str, Any],
        content: str,
        *,
        scene_type: str = "",
    ) -> dict[str, Any]:
        """Full graph-construction pipeline for one document.

        Args:
            doc_id:     Unique document identifier.
            metadata:   Dict with title, doc_number, issuing_org, doc_type,
                        signer, subject_words, publish_date, etc.
            content:    Full-text body of the document.
            scene_type: Document scene identifier for scene-based extraction
                        (e.g. "leader_speech_city").  Empty → default active set.

        Returns:
            Result dict with ``status``, ``entity_count``, ``relation_count``,
            ``referenced_doc_count``.
        """
        logger.info("graph_build_start", doc_id=doc_id, scene_type=scene_type or "(default)")

        try:
            # Step 1 – LLM entity extraction (scene-aware prompt)
            raw = await self._extract_via_llm(metadata, content, scene_type=scene_type)
            if "error" in raw:
                logger.error("graph_llm_error", doc_id=doc_id, error=raw["error"])
                return {"doc_id": doc_id, "status": "failed", "error": raw["error"]}

            # Step 2 – Normalise / validate entities (scene-aware type set)
            entities = self._normalise_entities(
                raw.get("entities") or [],
                scene_type=scene_type,
                doc_id=doc_id,
            )
            relations = raw.get("relations") or []
            ref_numbers: list[str] = raw.get("referenced_doc_numbers") or []

            # Step 3 – Regex scan for additional doc-number references
            scanned = self._scan_doc_numbers(content)
            # Merge lists, preserve order, deduplicate
            all_refs = list(dict.fromkeys(ref_numbers + scanned))

            # Step 3.5 – Generate normalized_title if missing
            if "normalized_title" not in metadata or not metadata["normalized_title"]:
                metadata["normalized_title"] = normalize_title(
                    metadata.get("title", "")
                )

            # Step 4 – Write subgraph to Neo4j (full runtime set validation)
            await self._write_to_neo4j(
                doc_id=doc_id,
                metadata=metadata,
                entities=entities,
                relations=relations,
                referenced_doc_numbers=all_refs,
            )

            # Scene statistics — entity type distribution
            raw_entity_count = len(raw.get("entities") or [])
            entity_type_dist: dict[str, int] = {}
            for ent in entities:
                t = ent.get("type", "unknown")
                entity_type_dist[t] = entity_type_dist.get(t, 0) + 1
            filtered_count = raw_entity_count - len(entities)

            result = {
                "doc_id": doc_id,
                "status": "completed",
                "entity_count": len(entities),
                "relation_count": len(relations),
                "referenced_doc_count": len(all_refs),
                "entities": entities,
                "relations": relations,
                "referenced_doc_numbers": all_refs,
                "prompt_template": {
                    "system_prompt": self._build_dynamic_system_prompt(scene_type=scene_type),
                    "user_prompt_template": ENTITY_EXTRACTION_USER,
                },
            }
            logger.info(
                "graph_build_complete",
                doc_id=doc_id,
                scene_type=scene_type or "(default)",
                entity_count=len(entities),
                relation_count=len(relations),
                referenced_doc_count=len(all_refs),
                raw_entity_count=raw_entity_count,
                filtered_entity_count=filtered_count,
                entity_type_distribution=entity_type_dist,
            )
            return result

        except Exception as exc:
            # 记录完整异常堆栈以便排查，同时返回 dict 保持接口兼容
            # Log full exception traceback for debugging while returning dict for API compatibility
            logger.exception("graph_build_failed", doc_id=doc_id, error=str(exc))
            return {"doc_id": doc_id, "status": "failed", "error": str(exc)}

    # ==================================================================
    # LLM extraction
    # ==================================================================

    async def _extract_via_llm(
        self,
        metadata: dict[str, Any],
        content: str,
        *,
        scene_type: str = "",
    ) -> dict[str, Any]:
        """Send document to LLM and return parsed JSON extraction result."""
        truncated = content[: self._max_chars]
        if len(content) > self._max_chars:
            truncated += "\n\n[... 正文已截断 ...]"

        subject_words: list[str] = metadata.get("subject_words") or []
        user_prompt = ENTITY_EXTRACTION_USER.format(
            title=metadata.get("title") or "未知",
            doc_number=metadata.get("doc_number") or "未知",
            issuing_org=metadata.get("issuing_org") or "未知",
            doc_type=metadata.get("doc_type") or "未知",
            signer=metadata.get("signer") or "未知",
            subject_words="、".join(subject_words) if subject_words else "无",
            content=truncated,
        )

        # Build system prompt dynamically from type cache (scene-aware)
        system_prompt = self._build_dynamic_system_prompt(scene_type=scene_type)

        return await self._llm.chat_json(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,   # Low temperature for deterministic extraction
            max_tokens=10240,
        )

    @staticmethod
    def _build_dynamic_system_prompt(*, scene_type: str = "") -> str:
        """Return the extraction system prompt for *scene_type*.

        The prompt is produced by ``build_scene_system_prompt``.  If that call
        raises for any reason, a warning (with traceback) is logged and the
        static ``ENTITY_EXTRACTION_SYSTEM`` prompt is returned instead.
        """
        try:
            prompt = build_scene_system_prompt(scene_type)
        except Exception:
            # Building the dynamic prompt failed; fall back to the hard-coded
            # default prompt.
            logger.warning("Failed to build dynamic system prompt, using default", exc_info=True)
            prompt = ENTITY_EXTRACTION_SYSTEM
        return prompt

    # ==================================================================
    # Entity normalisation
    # ==================================================================

    def _normalise_entities(
        self,
        raw: list[dict[str, Any]],
        *,
        scene_type: str = "",
        doc_id: str = "",
    ) -> list[dict[str, Any]]:
        """Validate, clean entity names, and deduplicate the list.

        Normalization rules are loaded from ``graph_schema.yaml``.
        Uses the **scene-based type set** for validation — ensuring phase_3
        entities extracted via scene-based prompts are NOT discarded.

        Parameters
        ----------
        raw:
            Raw entity list from LLM output.
        scene_type:
            Document scene identifier.  Empty → active-phases set.
        doc_id:
            Document ID, used for Person ID generation.
        """
        schema = get_schema()
        valid_types = _get_valid_entity_types_for_scene(scene_type)
        seen: set[str] = set()
        out: list[dict[str, Any]] = []
        filtered_types: dict[str, int] = {}

        for ent in raw:
            etype = (ent.get("type") or "").strip()
            name = (ent.get("name") or "").strip()
            props = dict(ent.get("properties") or {})

            if not name or etype not in valid_types:
                if etype:
                    filtered_types[etype] = filtered_types.get(etype, 0) + 1
                continue

            # ── config-driven normalisation ──
            rule = schema.get_norm_rule(etype)
            if rule.get("strip_punctuation"):
                name = _strip_punct(name)
            if rule.get("trim_whitespace"):
                name = name.strip()
            max_len = rule.get("max_length")
            if max_len and len(name) > max_len:
                name = name[:max_len]

            if not name:
                continue

            # ── deduplication key ──
            dedup_key = f"{etype}::{name}"
            if dedup_key in seen:
                continue
            seen.add(dedup_key)

            out.append({"type": etype, "name": name, "properties": props})

        if filtered_types:
            logger.debug(
                "entity_types_filtered",
                doc_id=doc_id,
                scene_type=scene_type or "(default)",
                valid_types=sorted(valid_types),
                filtered_type_counts=filtered_types,
            )

        return out

    # ==================================================================
    # Document-number scanning
    # ==================================================================

    @staticmethod
    def _scan_doc_numbers(content: str) -> list[str]:
        """Collect the document codes that ``iter_doc_codes`` yields for *content*."""
        return [*iter_doc_codes(content)]

    # ==================================================================
    # Neo4j write
    # ==================================================================

    async def _write_to_neo4j(
        self,
        doc_id: str,
        metadata: dict[str, Any],
        entities: list[dict[str, Any]],
        relations: list[dict[str, Any]],
        referenced_doc_numbers: list[str],
    ) -> None:
        """Convert LLM output to Neo4j-compatible structures and persist.

        将 LLM 抽取结果转换为 Neo4j 可写入的实体/关系结构。
        处理要点：
        - CURRENT_DOC 占位符替换为实际 doc_id
        - 被引用文档创建为 is_placeholder=True 的占位节点
        - 已有强类型关系（BASED_ON/AMENDS/REPEALS）的文号不再重复创建 REFERENCES
        - 使用全量运行时支持集合做类型校验（不受 active_phases 限制）
        - 统一主键解析：通过 _resolve_entity_key() 确保非 name 主键实体关系正确连接
        """
        # 使用全量运行时支持集合（所有 phase），而非 active_phases 过滤后的集合
        valid_rel_types = _get_all_valid_rel_types()
        valid_entity_types = _get_all_valid_entity_types()
        all_valid_labels = valid_entity_types | {"Document"}

        # ── build entity list ──
        neo_entities: list[dict[str, Any]] = []
        for ent in entities:
            props = dict(ent["properties"])
            props["name"] = ent["name"]
            label = ent["type"]

            # Person uses special ID generation (name + serving_org + position + doc_id)
            if label == "Person":
                if "person_id" not in props:
                    props["person_id"] = _generate_person_id(
                        ent["name"], props, doc_id
                    )
            else:
                # Auto-generate internal *_id for other entity types
                id_field = _AUTO_ID_FIELDS.get(label)
                if id_field and id_field not in props:
                    props[id_field] = _generate_entity_id(label, ent["name"])

            neo_entities.append({"label": label, "properties": props})

        # ── unified key resolution: build (label, name) -> key_value mapping ──
        # This ensures all non-name primary key entities (e.g. Person with
        # person_id) have their relationships connected via the correct key.
        entity_key_map: dict[tuple[str, str], str] = {}
        for neo_ent in neo_entities:
            label = neo_ent["label"]
            props = neo_ent["properties"]
            name = props.get("name", "")
            key_val = _resolve_entity_key(label, name, props)
            entity_key_map[(label, name)] = key_val

        # ── build relationship list ──
        neo_rels: list[dict[str, Any]] = []
        for rel in relations:
            src_type = (rel.get("source_type") or "").strip()
            src_name = (rel.get("source_name") or "").strip()
            tgt_type = (rel.get("target_type") or "").strip()
            tgt_name = (rel.get("target_name") or "").strip()
            rtype = (rel.get("relation") or "").strip()

            if not all([src_type, src_name, tgt_type, tgt_name, rtype]):
                continue
            if rtype not in valid_rel_types:
                continue
            if src_type not in all_valid_labels or tgt_type not in all_valid_labels:
                continue

            # Resolve CURRENT_DOC placeholder to the actual doc_id
            if src_name == "CURRENT_DOC":
                src_type, src_name = "Document", doc_id
            if tgt_name == "CURRENT_DOC":
                tgt_type, tgt_name = "Document", doc_id

            # For Document→Document relations where target is a referenced doc
            # (not CURRENT_DOC), ensure placeholder node exists and use ref: ID.
            if (
                tgt_type == "Document"
                and tgt_name != doc_id
                and not tgt_name.startswith("ref:")
            ):
                placeholder_id = f"ref:{tgt_name}"
                # Ensure placeholder entity is created
                neo_entities.append(
                    {
                        "label": "Document",
                        "properties": {
                            "doc_id": placeholder_id,
                            "doc_code": tgt_name,
                            "is_placeholder": True,
                        },
                    }
                )
                tgt_name = placeholder_id

            # ── resolve connection keys via unified key resolver ──
            # For Document types (already resolved above), use the name directly
            # (which is doc_id after CURRENT_DOC substitution).
            # For non-Document types, look up from entity_key_map.
            if src_type == "Document":
                from_key = src_name
            else:
                from_key = entity_key_map.get((src_type, src_name), src_name)

            if tgt_type == "Document":
                to_key = tgt_name
            else:
                to_key = entity_key_map.get((tgt_type, tgt_name), tgt_name)

            neo_rels.append(
                {
                    "from_label": src_type,
                    "from_key": from_key,
                    "to_label": tgt_type,
                    "to_key": to_key,
                    "type": rtype,
                    "properties": rel.get("properties") or {},
                }
            )

        # ── referenced document stubs ──
        # Collect doc_numbers that already have a typed relation
        # (BASED_ON/AMENDS/REPEALS) so we don't create a duplicate REFERENCES.
        # Keys may be raw doc numbers or ref:-prefixed IDs.
        typed_doc_targets: set[str] = set()
        for rel in neo_rels:
            if (
                rel["to_label"] == "Document"
                and rel["type"] in ("BASED_ON", "AMENDS", "REPEALS")
            ):
                key = rel["to_key"]
                typed_doc_targets.add(key)
                # Also add the raw form if it's a ref: ID, and vice versa
                if key.startswith("ref:"):
                    typed_doc_targets.add(key[4:])
                else:
                    typed_doc_targets.add(f"ref:{key}")

        own_code = metadata.get("doc_code") or metadata.get("doc_number", "")
        for ref_num in referenced_doc_numbers:
            if ref_num == own_code:
                continue
            placeholder_id = f"ref:{ref_num}"
            # Skip REFERENCES if this doc already has a stronger typed relation
            if ref_num in typed_doc_targets or placeholder_id in typed_doc_targets:
                continue
            # Create a lightweight placeholder Document node
            neo_entities.append(
                {
                    "label": "Document",
                    "properties": {
                        "doc_id": placeholder_id,
                        "doc_code": ref_num,
                        "is_placeholder": True,
                    },
                }
            )
            neo_rels.append(
                {
                    "from_label": "Document",
                    "from_key": doc_id,
                    "to_label": "Document",
                    "to_key": placeholder_id,
                    "type": "REFERENCES",
                    "properties": {"doc_code": ref_num},
                }
            )

        await self._neo4j.merge_document_graph(
            doc_id=doc_id,
            metadata=metadata,
            entities=neo_entities,
            relationships=neo_rels,
        )


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _strip_punct(s: str) -> str:
    """Strip common surrounding quotes and punctuation from entity names.

    Only leading and trailing characters are removed; punctuation inside the
    name is kept.  The stripped set is: space, ASCII double/single quotes,
    curly double/single quotes (U+201C, U+201D, U+2018, U+2019) and the
    full-width marks 、 。 ， ； ！ ？.

    The curly quotes are written as escapes on purpose: as literal ASCII
    quote characters they silently collapse into duplicate ``"`` / adjacent
    string literals and the curly forms are never stripped.
    """
    strip_chars = (
        " \"'"
        "\u201c\u201d\u2018\u2019"
        "\u3001\u3002\uff0c\uff1b\uff01\uff1f"
    )
    return s.strip(strip_chars)
