"""Load graph schema definitions from ``config/graph_schema.yaml``.

Provides a cached, validated view of entity types, relationship types,
normalization rules, and knowledge-category mappings.  All other graph
modules should obtain type definitions through this loader rather than
hard-coding them.

Three distinct type-set views are provided:

- **Active set** (``entity_types`` / ``relationship_types``):
  Filtered by ``active_phases``.  Used as the default extraction set.
- **Unfiltered set** (``all_entity_types_unfiltered()`` etc.):
  All types regardless of phase.  Used for Neo4j constraint creation,
  write-path validation, query-layer whitelists, and admin type lists.
- **Scene set** (``entity_types_for_scene(scene)`` etc.):
  ``active_phases + extraction_scenes[scene].extra_phases``.
  Used by the extraction pipeline for scene-specific prompts.

Usage::

    from app.core.graph_schema_loader import get_schema

    schema = get_schema()
    entity_types = schema.entity_types          # list[dict] (active only)
    all_types    = schema.all_entity_types_unfiltered()  # list[dict] (full)
    scene_types  = schema.entity_types_for_scene("leader_speech_city")

图谱 Schema 加载器模块。
从 config/graph_schema.yaml 读取并缓存实体类型、关系类型、
规范化规则和知识分类映射等配置。所有图谱相关模块都应通过
get_schema() 获取类型定义，而非硬编码。
支持按 active_phases 过滤类型定义，实现分阶段功能开放。
支持场景化抽取（extraction_scenes）按需启用额外 phase。
"""

from __future__ import annotations

import re
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from app.utils.logger import get_logger

# Module-level logger; log calls in this module pass an event name followed by
# keyword fields (e.g. ``logger.info("graph_schema_loaded", active_phases=...)``).
logger = get_logger(__name__)

# Location of the schema file: ``config/graph_schema.yaml`` inside the parent
# of the directory that contains this module (resolved from ``__file__``).
_SCHEMA_PATH = Path(__file__).resolve().parent.parent / "config" / "graph_schema.yaml"

# Safe identifier for schema type names: an ASCII letter or underscore followed
# by at most 49 ASCII letters/digits/underscores (50 chars total).  Explicit
# ASCII classes and ``\Z`` are used on purpose: in ``str`` patterns ``\w`` also
# matches non-ASCII letters/digits, and ``$`` also matches just before a
# trailing newline, so ``\w``/``$`` would accept names such as "Person\n".
_SAFE_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]{0,49}\Z")


# ---------------------------------------------------------------------------
# Public data class
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class GraphSchema:
    """Immutable snapshot of the graph schema configuration.

    Exposes three views of the type definitions:

    * **active set** — ``entity_types`` / ``relationship_types``; as built by
      the loader these hold only types whose ``phase`` (default
      ``"phase_0"``) is listed in ``active_phases``;
    * **unfiltered set** — the ``*_unfiltered()`` helpers, covering every
      phase;
    * **scene set** — the ``*_for_scene()`` helpers, covering
      ``active_phases`` plus the scene's ``extra_phases``.

    ``frozen=True`` only prevents re-binding attributes.  The list/dict
    fields and the per-type definition dicts are ordinary mutable objects
    (the ``*_unfiltered()`` list helpers return new lists around the same
    dicts), so callers must treat everything reachable from a schema as
    read-only.
    """

    # ── active-phases-filtered (default extraction set) ──
    entity_types: list[dict[str, Any]] = field(default_factory=list)
    relationship_types: list[dict[str, Any]] = field(default_factory=list)

    # ── unfiltered (runtime support set — all phases) ──
    # Tuples rather than lists so the containers themselves cannot be
    # appended to through the frozen instance.
    _all_entity_types: tuple[dict[str, Any], ...] = field(default_factory=tuple)
    _all_relationship_types: tuple[dict[str, Any], ...] = field(default_factory=tuple)

    # ── scene-based extraction config (scene name -> scene settings) ──
    extraction_scenes: dict[str, dict[str, Any]] = field(default_factory=dict)

    # ── shared config ──
    normalization_rules: dict[str, dict[str, Any]] = field(default_factory=dict)
    document_properties: list[str] = field(default_factory=list)
    knowledge_category_mapping: dict[str, str] = field(default_factory=dict)
    doc_types: list[str] = field(default_factory=list)
    active_phases: list[str] = field(default_factory=list)

    # =====================================================================
    # Active-set helpers (filtered by active_phases)
    # =====================================================================

    def entity_type_names(self) -> set[str]:
        """Return the names of the active entity types (``Document`` excluded)."""
        return {et["name"] for et in self.entity_types}

    def rel_type_names(self) -> set[str]:
        """Return the names of the active relationship types."""
        return {rt["name"] for rt in self.relationship_types}

    def entity_type_map(self) -> dict[str, dict[str, Any]]:
        """Return mapping ``name -> full definition`` for active entity types."""
        return {et["name"]: et for et in self.entity_types}

    def rel_type_map(self) -> dict[str, dict[str, Any]]:
        """Return mapping ``name -> full definition`` for active relationship types."""
        return {rt["name"]: rt for rt in self.relationship_types}

    def get_norm_rule(self, entity_name: str) -> dict[str, Any]:
        """Return the normalization rule for *entity_name*, or an empty dict."""
        return self.normalization_rules.get(entity_name, {})

    def knowledge_category_code(self, category_name: str) -> str:
        """Map a (Chinese) knowledge-category name to its code, or ``""``."""
        return self.knowledge_category_mapping.get(category_name, "")

    def all_node_labels(self) -> set[str]:
        """Return active node labels including ``Document``."""
        return self.entity_type_names() | {"Document"}

    def all_rel_types(self) -> set[str]:
        """Return active relationship type names (alias of ``rel_type_names``)."""
        return self.rel_type_names()

    # =====================================================================
    # Unfiltered helpers (runtime support set — all phases)
    # =====================================================================

    def all_entity_types_unfiltered(self) -> list[dict[str, Any]]:
        """Return ALL entity type definitions regardless of active_phases.

        The list is new on every call; the definition dicts are shared.
        """
        return list(self._all_entity_types)

    def all_rel_types_unfiltered(self) -> list[dict[str, Any]]:
        """Return ALL relationship type definitions regardless of active_phases.

        The list is new on every call; the definition dicts are shared.
        """
        return list(self._all_relationship_types)

    def all_entity_type_names_unfiltered(self) -> set[str]:
        """Return ALL entity-type names (``Document`` excluded)."""
        return {et["name"] for et in self._all_entity_types}

    def all_rel_type_names_unfiltered(self) -> set[str]:
        """Return ALL relationship-type names."""
        return {rt["name"] for rt in self._all_relationship_types}

    def all_node_labels_unfiltered(self) -> set[str]:
        """Return ALL node labels including ``Document``.

        Used by query-layer whitelist and Neo4j constraint creation.
        """
        return self.all_entity_type_names_unfiltered() | {"Document"}

    def all_entity_type_map_unfiltered(self) -> dict[str, dict[str, Any]]:
        """Return full ``name -> definition`` mapping for ALL entity types.

        Used by ``_key_prop()`` and other lookups that must resolve
        key_property for any label (including phase_3).
        """
        return {et["name"]: et for et in self._all_entity_types}

    def all_rel_type_map_unfiltered(self) -> dict[str, dict[str, Any]]:
        """Return full ``name -> definition`` mapping for ALL relationship types."""
        return {rt["name"]: rt for rt in self._all_relationship_types}

    def key_prop_for(self, label: str) -> str:
        """Return the business key property name for *label*.

        Most entity types use ``"name"`` as their key property.  Exceptions
        (e.g. Person → ``"person_id"``) are declared in ``graph_schema.yaml``
        via the ``key_property`` field.  Entity types of every phase are
        searched, not only the active ones.

        Falls back to ``"name"`` for unknown labels and when
        ``key_property`` is missing, null or empty.
        """
        # Reverse scan: a duplicated name resolves to its last definition,
        # exactly like the dict-based maps above, without building a full
        # map on every call.
        for et in reversed(self._all_entity_types):
            if et.get("name") == label:
                # ``or`` also covers ``key_property: null`` in YAML, which
                # ``.get(..., "name")`` would have returned as ``None``.
                return et.get("key_property") or "name"
        return "name"

    # =====================================================================
    # Scene-based helpers
    # =====================================================================

    def get_scene_phases(self, scene_type: str) -> list[str]:
        """Return ``active_phases + extra_phases`` for *scene_type*, deduped.

        Order is preserved (active phases first).  If *scene_type* is empty,
        not configured, or configured with an empty/null entry, returns a
        copy of ``active_phases``.  A bare string ``extra_phases`` value is
        treated as a single phase.
        """
        scene_cfg = self.extraction_scenes.get(scene_type) if scene_type else None
        if not scene_cfg:
            return list(self.active_phases)
        extra = scene_cfg.get("extra_phases") or []
        if isinstance(extra, str):
            # ``extra_phases: phase_3`` in YAML yields a plain string;
            # iterating it would contribute single characters, not a phase.
            extra = [extra]
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        return list(dict.fromkeys([*self.active_phases, *extra]))

    def entity_types_for_scene(self, scene_type: str) -> list[dict[str, Any]]:
        """Return entity types (any phase) whose phase is enabled for *scene_type*."""
        phases = set(self.get_scene_phases(scene_type))
        return [
            et for et in self._all_entity_types
            if et.get("phase", "phase_0") in phases
        ]

    def rel_types_for_scene(self, scene_type: str) -> list[dict[str, Any]]:
        """Return relationship types (any phase) whose phase is enabled for *scene_type*."""
        phases = set(self.get_scene_phases(scene_type))
        return [
            rt for rt in self._all_relationship_types
            if rt.get("phase", "phase_0") in phases
        ]

    def entity_type_names_for_scene(self, scene_type: str) -> set[str]:
        """Return entity-type names active for *scene_type*."""
        return {et["name"] for et in self.entity_types_for_scene(scene_type)}

    def rel_type_names_for_scene(self, scene_type: str) -> set[str]:
        """Return relationship-type names active for *scene_type*."""
        return {rt["name"] for rt in self.rel_types_for_scene(scene_type)}


# ---------------------------------------------------------------------------
# Module-level cache
# ---------------------------------------------------------------------------

# Process-wide cache for the loaded schema; ``None`` until the first
# successful ``get_schema()`` call.
_cache: GraphSchema | None = None
# Serializes the check-and-load sequence in ``get_schema()`` so concurrent
# callers never load the YAML file at the same time.
_cache_lock = threading.Lock()


def get_schema(*, force_reload: bool = False) -> GraphSchema:
    """Return the process-wide :class:`GraphSchema`, loading it on demand.

    The YAML file is parsed and validated on the first call and again
    whenever *force_reload* is true; otherwise the previously cached
    instance is returned as-is.  The check-and-load sequence runs while
    holding ``_cache_lock``.  If loading raises, the exception propagates
    and any previously cached schema stays in place.
    """
    global _cache
    with _cache_lock:
        if force_reload or _cache is None:
            _cache = _load_and_validate()
        return _cache


def reload_schema() -> GraphSchema:
    """Re-read the schema from disk, replace the cached copy and return it.

    Equivalent to ``get_schema(force_reload=True)``.
    """
    fresh = get_schema(force_reload=True)
    return fresh


# ---------------------------------------------------------------------------
# Internal loading / validation
# ---------------------------------------------------------------------------

def _load_and_validate() -> GraphSchema:
    """Read ``graph_schema.yaml``, validate it and build a :class:`GraphSchema`.

    Every top-level section is optional: a missing or null section falls
    back to an empty container, and ``active_phases`` falls back to
    ``["phase_0"]``.  A bare string where a phase list is expected is treated
    as a single-element list.

    Returns:
        A new :class:`GraphSchema` carrying both the ``active_phases``-filtered
        type lists and the unfiltered (all-phase) definitions.

    Raises:
        FileNotFoundError: ``_SCHEMA_PATH`` does not exist.
        TypeError: the document root, ``entity_types``,
            ``relationship_types``, ``extraction_scenes``, a type entry, a
            scene entry or a phase list has the wrong container type.
        ValueError: a type name is not a safe identifier, or ``phase_2b`` is
            enabled without ``phase_2a`` (globally or by an extraction scene).
    """
    if not _SCHEMA_PATH.exists():
        raise FileNotFoundError(
            f"Graph schema config not found: {_SCHEMA_PATH}"
        )

    with open(_SCHEMA_PATH, "r", encoding="utf-8") as fh:
        # safe_load only builds plain YAML types, never arbitrary objects.
        raw: Any = yaml.safe_load(fh) or {}

    # A list/scalar root would otherwise surface later as an opaque
    # AttributeError on ``raw.get``.
    if not isinstance(raw, dict):
        raise TypeError(
            "graph_schema.yaml: top level must be a mapping, "
            f"got {type(raw).__name__}"
        )

    # These sections may be absent or null, but when present they must be lists.
    for list_field in ("entity_types", "relationship_types"):
        value = raw.get(list_field)
        if value is not None and not isinstance(value, list):
            raise TypeError(
                f"graph_schema.yaml: '{list_field}' must be a list, "
                f"got {type(value).__name__}"
            )

    # Normalized to a real list: a bare string would otherwise turn every
    # ``in active_phases`` test below into a substring test.
    active_phases = (
        _as_phase_list(raw.get("active_phases"), "active_phases") or ["phase_0"]
    )
    _check_phase_dependencies(active_phases, "active_phases")

    # ── validate ALL entity / relationship types (all phases) ──────────────
    all_entity_types: list[dict] = raw.get("entity_types") or []
    _validate_type_defs(all_entity_types, "entity_types", "entity_type")

    all_rel_types: list[dict] = raw.get("relationship_types") or []
    _validate_type_defs(all_rel_types, "relationship_types", "relationship_type")

    # ── filter by active_phases for default extraction set ────────────────
    entity_types: list[dict] = [
        et for et in all_entity_types
        if et.get("phase", "phase_0") in active_phases
    ]
    rel_types: list[dict] = [
        rt for rt in all_rel_types
        if rt.get("phase", "phase_0") in active_phases
    ]

    extraction_scenes = _validate_scenes(raw.get("extraction_scenes"), active_phases)

    norm_rules: dict[str, dict] = raw.get("normalization_rules") or {}
    doc_props: list[str] = raw.get("document_properties") or []
    kc_mapping: dict[str, str] = raw.get("knowledge_category_mapping") or {}
    doc_types: list[str] = raw.get("doc_types") or []

    schema = GraphSchema(
        entity_types=entity_types,
        relationship_types=rel_types,
        _all_entity_types=tuple(all_entity_types),
        _all_relationship_types=tuple(all_rel_types),
        extraction_scenes=extraction_scenes,
        normalization_rules=norm_rules,
        document_properties=doc_props,
        knowledge_category_mapping=kc_mapping,
        doc_types=doc_types,
        active_phases=active_phases,
    )

    logger.info(
        "graph_schema_loaded",
        active_phases=active_phases,
        entity_types=[et["name"] for et in entity_types],
        relationship_types=[rt["name"] for rt in rel_types],
        total_entity_types=len(all_entity_types),
        total_rel_types=len(all_rel_types),
        extraction_scenes=list(extraction_scenes.keys()),
    )

    return schema


def _as_phase_list(value: Any, where: str) -> list[str]:
    """Normalize a YAML phase list; falsy -> ``[]``, bare string -> ``[value]``.

    Raises:
        TypeError: *value* is neither falsy, a string, nor a list.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        return list(value)
    raise TypeError(
        f"graph_schema.yaml: '{where}' must be a list of phase names, "
        f"got {type(value).__name__}"
    )


def _check_phase_dependencies(phases: list[str], where: str) -> None:
    """Reject a phase set that enables a phase without its prerequisite.

    Raises:
        ValueError: ``phase_2b`` is in *phases* but ``phase_2a`` is not.
    """
    # PRD §2.1: phase_2b depends on phase_2a; cannot be enabled independently.
    if "phase_2b" in phases and "phase_2a" not in phases:
        raise ValueError(
            "phase_2b depends on phase_2a — cannot enable phase_2b "
            f"without phase_2a in {where}."
        )


def _validate_type_defs(defs: list[Any], section: str, kind: str) -> None:
    """Check every entry of *section* is a mapping whose ``name`` is safe.

    Raises:
        TypeError: an entry is not a mapping.
        ValueError: an entry's name fails ``_validate_identifier``.
    """
    for index, definition in enumerate(defs):
        if not isinstance(definition, dict):
            raise TypeError(
                f"graph_schema.yaml: '{section}[{index}]' must be a mapping, "
                f"got {type(definition).__name__}"
            )
        _validate_identifier(definition.get("name", ""), kind)


def _validate_scenes(
    raw_scenes: Any, active_phases: list[str]
) -> dict[str, dict[str, Any]]:
    """Validate ``extraction_scenes`` and normalize each ``extra_phases``.

    Returns a new mapping; scene dicts that declare ``extra_phases`` are
    shallow-copied with that value normalized to a list.

    Raises:
        TypeError: the section or a non-empty scene entry is not a mapping,
            or a scene's ``extra_phases`` is not a phase list.
        ValueError: a scene enables ``phase_2b`` without ``phase_2a``.
    """
    if not raw_scenes:
        return {}
    if not isinstance(raw_scenes, dict):
        raise TypeError(
            "graph_schema.yaml: 'extraction_scenes' must be a mapping, "
            f"got {type(raw_scenes).__name__}"
        )
    scenes: dict[str, dict[str, Any]] = {}
    for scene_name, scene_cfg in raw_scenes.items():
        # Empty/null entries are kept unchanged; GraphSchema.get_scene_phases
        # treats a falsy scene config as "not configured".
        if not scene_cfg:
            scenes[scene_name] = scene_cfg
            continue
        if not isinstance(scene_cfg, dict):
            raise TypeError(
                f"graph_schema.yaml: 'extraction_scenes.{scene_name}' must be "
                f"a mapping, got {type(scene_cfg).__name__}"
            )
        where = f"extraction_scenes.{scene_name}.extra_phases"
        extra = _as_phase_list(scene_cfg.get("extra_phases"), where)
        # A scene extracts with active_phases + extra_phases, so the same
        # prerequisite rule must hold for that combined set.
        _check_phase_dependencies(active_phases + extra, f"active_phases + {where}")
        if "extra_phases" in scene_cfg:
            scene_cfg = {**scene_cfg, "extra_phases": extra}
        scenes[scene_name] = scene_cfg
    return scenes


def _validate_identifier(name: str, kind: str) -> None:
    """Ensure *name* is a safe schema identifier.

    Validated names feed Neo4j constraint creation and the query-layer
    whitelists, so only non-empty ASCII strings that fully match
    ``_SAFE_IDENTIFIER_RE`` are accepted.

    Args:
        name: The type name read from ``graph_schema.yaml``.
        kind: Label used in error messages (e.g. ``"entity_type"``).

    Raises:
        ValueError: *name* is empty, not a string, non-ASCII, or does not
            fully match the identifier pattern.
    """
    if not name:
        raise ValueError(f"Empty name for {kind} in graph_schema.yaml")
    # fullmatch (not match): a ``$`` anchor under match also accepts a
    # trailing newline.  isascii: ``\w`` in str patterns matches non-ASCII
    # word characters.  isinstance: YAML may yield ints/lists for ``name``,
    # which ``re`` would reject with an unrelated TypeError.
    if not (
        isinstance(name, str)
        and name.isascii()
        and _SAFE_IDENTIFIER_RE.fullmatch(name)
    ):
        raise ValueError(
            f"Invalid {kind} name '{name}' in graph_schema.yaml. "
            f"Must be ASCII and match {_SAFE_IDENTIFIER_RE.pattern}"
        )
