"""Neo4j async client wrapper for knowledge-graph operations.

Neo4j 异步客户端封装模块。
负责知识图谱的 Schema 初始化、文档子图的 MERGE 写入、
实体/关系的 CRUD 操作以及文档邻域子图查询。
支持基于 graph_schema.yaml 的动态类型校验和占位文档节点合并。
"""

from __future__ import annotations

import re
from collections import defaultdict
from typing import Any

from neo4j import AsyncDriver, AsyncGraphDatabase

from app.config import settings
from app.core.graph_schema_loader import get_schema
from app.utils.logger import get_logger

logger = get_logger(__name__)

# ── Document enum validation ─────────────────────────────────────────────────
# PRD §4.2: values outside these enums must not be written to the graph.
# Allowed Document.status values (policy-document lifecycle states).
VALID_DOC_STATUS = {"有效", "部分失效", "已废止", "失效", "待确认"}
# Allowed Document.admin_level values (administrative levels, national → township).
VALID_DOC_ADMIN_LEVEL = {"国家级", "省级", "市级", "区县级", "乡镇街道级", "其他", "未知"}

# Cypher 标签/类型名称合法性正则：字母或下划线开头，最长 50 字符
# Regex for validating Cypher identifiers (labels, rel types) to prevent injection
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_]\w{0,49}$")


def _validate_identifier(name: str) -> str:
    """校验 Cypher 标签名或关系类型名是否合法，不合法则抛出 ValueError。
    Validate a Cypher identifier (node label or relationship type).

    Prevents Cypher injection when identifiers are interpolated into f-strings.
    """
    if not _IDENTIFIER_RE.match(name):
        raise ValueError(
            f"Invalid Cypher identifier: {name!r} — "
            f"must match {_IDENTIFIER_RE.pattern}"
        )
    return name


def _validate_doc_enum(props: dict[str, Any]) -> dict[str, Any]:
    """Validate and sanitise Document enum fields before Neo4j write.

    - status: must be in VALID_DOC_STATUS, otherwise cleared with a warning.
    - admin_level: must be in VALID_DOC_ADMIN_LEVEL, otherwise cleared.
    """
    status = props.get("status", "")
    if status and status not in VALID_DOC_STATUS:
        logger.warning("invalid_doc_status", status=status, allowed=VALID_DOC_STATUS)
        props.pop("status", None)

    admin_level = props.get("admin_level", "")
    if admin_level and admin_level not in VALID_DOC_ADMIN_LEVEL:
        logger.warning(
            "invalid_doc_admin_level",
            admin_level=admin_level,
            allowed=VALID_DOC_ADMIN_LEVEL,
        )
        props.pop("admin_level", None)

    return props


def _default_node_labels() -> set[str]:
    """Default valid node labels: every label in graph_schema.yaml (all phases)."""
    schema = get_schema()
    return schema.all_node_labels_unfiltered()


def _default_rel_types() -> set[str]:
    """Default valid relationship types from graph_schema.yaml (all phases)."""
    schema = get_schema()
    return schema.all_rel_type_names_unfiltered()


# Module-level mutable sets — replaced atomically by reload_valid_types().
# Under CPython's GIL, assignment to a global name is atomic, so readers
# need no lock.  Updates must REPLACE the set wholesale (bind a NEW set
# object) rather than mutate it in place with .add()/.remove(); otherwise
# a reader could observe a half-updated set in the gap when the GIL is
# released between bytecode steps.
# Valid node labels; empty until reload_valid_types() populates it.
_VALID_NODE_LABELS: set[str] = set()
# Valid relationship type names; empty until reload_valid_types() populates it.
_VALID_REL_TYPES: set[str] = set()


def reload_valid_types(
    node_labels: set[str] | None = None,
    rel_types: set[str] | None = None,
) -> None:
    """Atomically replace the module-level valid label / rel-type sets.

    Invoked by GraphAdminService after loading or mutating type
    definitions.  The built-in "Document" label is always included.
    Passing ``None`` for either argument reloads that set's defaults
    from ``graph_schema.yaml``.
    """
    global _VALID_NODE_LABELS, _VALID_REL_TYPES  # noqa: PLW0603
    _VALID_NODE_LABELS = (
        node_labels | {"Document"}
        if node_labels is not None
        else _default_node_labels()
    )
    _VALID_REL_TYPES = rel_types if rel_types is not None else _default_rel_types()


class Neo4jClient:
    """Async Neo4j driver wrapper with schema initialisation helpers.

    Provides the core graph operations: schema initialisation (constraints
    and indexes), idempotent document subgraph merging, entity CRUD and
    neighbourhood sub-graph queries.
    """

    def __init__(self, driver: AsyncDriver) -> None:
        self._driver = driver

    @classmethod
    def from_settings(cls) -> "Neo4jClient":
        """Create a client from the global application settings."""
        driver = AsyncGraphDatabase.driver(
            settings.neo4j_uri,
            auth=(settings.neo4j_user, settings.neo4j_password),
        )
        return cls(driver)

    @property
    def driver(self) -> AsyncDriver:
        """The wrapped neo4j AsyncDriver instance."""
        return self._driver

    async def close(self) -> None:
        """Close the underlying driver and release all connections."""
        await self._driver.close()

    # ── Schema initialisation ────────────────────────────────────────────

    async def init_schema(self) -> None:
        """Create uniqueness constraints, property indexes and fulltext
        indexes required by the knowledge graph.

        Entity types are loaded from ``graph_schema.yaml``; a uniqueness
        constraint on each type's ``key_property`` is created.  All
        statements use ``IF NOT EXISTS`` and are safe to re-run.
        """
        schema = get_schema()

        async with self._driver.session(database=settings.neo4j_database) as session:
            # -- Document constraints (always present) --
            await session.run(
                "CREATE CONSTRAINT doc_id_unique IF NOT EXISTS "
                "FOR (d:Document) REQUIRE d.doc_id IS UNIQUE"
            )

            # -- Dynamic uniqueness constraints for ALL entity types (unfiltered) --
            # Constraints are created for every phase, including inactive ones:
            # an unused constraint carries no data and costs nothing, and this
            # guarantees the constraint already exists by the time scenario
            # extraction starts writing phase_3 entities.
            for et in schema.all_entity_types_unfiltered():
                label = et["name"]
                key_prop = et.get("key_property", "name")
                # Validate identifiers before interpolation to prevent injection
                _validate_identifier(label)
                _validate_identifier(key_prop)
                constraint_name = f"{label.lower()}_{key_prop}_unique"
                cypher = (
                    f"CREATE CONSTRAINT {constraint_name} IF NOT EXISTS "
                    f"FOR (n:{label}) REQUIRE n.{key_prop} IS UNIQUE"
                )
                await session.run(cypher)

            # -- Fulltext indexes (name / summary search per core label) --
            fulltext_index_statements = (
                "CREATE FULLTEXT INDEX doc_fulltext IF NOT EXISTS "
                "FOR (d:Document) ON EACH [d.title, d.subject_summary, d.doc_code]",
                "CREATE FULLTEXT INDEX org_fulltext IF NOT EXISTS "
                "FOR (o:Organization) ON EACH [o.name]",
                "CREATE FULLTEXT INDEX matter_fulltext IF NOT EXISTS "
                "FOR (m:Matter) ON EACH [m.name, m.description]",
                "CREATE FULLTEXT INDEX policy_fulltext IF NOT EXISTS "
                "FOR (p:Policy) ON EACH [p.name, p.summary]",
                "CREATE FULLTEXT INDEX task_fulltext IF NOT EXISTS "
                "FOR (t:Task) ON EACH [t.name]",
                "CREATE FULLTEXT INDEX project_fulltext IF NOT EXISTS "
                "FOR (p:Project) ON EACH [p.name]",
                "CREATE FULLTEXT INDEX system_fulltext IF NOT EXISTS "
                "FOR (s:System) ON EACH [s.name]",
                "CREATE FULLTEXT INDEX dataresource_fulltext IF NOT EXISTS "
                "FOR (dr:DataResource) ON EACH [dr.name]",
                "CREATE FULLTEXT INDEX budget_fulltext IF NOT EXISTS "
                "FOR (b:Budget) ON EACH [b.name]",
                "CREATE FULLTEXT INDEX indicator_fulltext IF NOT EXISTS "
                "FOR (i:Indicator) ON EACH [i.name]",
                "CREATE FULLTEXT INDEX industry_fulltext IF NOT EXISTS "
                "FOR (ind:Industry) ON EACH [ind.name]",
            )
            for stmt in fulltext_index_statements:
                await session.run(stmt)

            # -- Property indexes for common Document look-ups --
            property_index_statements = (
                "CREATE INDEX doc_publish_date IF NOT EXISTS "
                "FOR (d:Document) ON (d.publish_date)",
                "CREATE INDEX doc_doc_code IF NOT EXISTS "
                "FOR (d:Document) ON (d.doc_code)",
                "CREATE INDEX doc_status IF NOT EXISTS "
                "FOR (d:Document) ON (d.status)",
                "CREATE INDEX doc_knowledge_category IF NOT EXISTS "
                "FOR (d:Document) ON (d.knowledge_category_code)",
                "CREATE INDEX doc_normalized_title IF NOT EXISTS "
                "FOR (d:Document) ON (d.normalized_title)",
            )
            for stmt in property_index_statements:
                await session.run(stmt)

        logger.info("neo4j_schema_initialized")

    # ── Document graph merge (preferred API) ─────────────────────────────

    async def merge_document_graph(
        self,
        doc_id: str,
        metadata: dict[str, Any],
        entities: list[dict[str, Any]],
        relationships: list[dict[str, Any]],
    ) -> None:
        """Merge a complete document subgraph into Neo4j (MERGE, idempotent).

        Writes the Document node, all entity nodes and all relationships,
        then absorbs any placeholder Document nodes sharing this doc_code.

        Parameters
        ----------
        doc_id:
            Unique document ID used as the Document node key.
        metadata:
            Document-level metadata written onto the Document node.
        entities:
            Each dict: ``{"label": "<NodeLabel>", "properties": {...}}``.
            Document nodes are keyed by ``doc_id``.  All other nodes are
            keyed by ``name`` (or the ``key_property`` defined in
            ``graph_schema.yaml``).
        relationships:
            Each dict: ``{"from_label", "from_key", "to_label", "to_key",
            "type", "properties"}``.  ``from_key`` / ``to_key`` are the
            matching key values (``doc_id`` for Document nodes, ``name``
            for all others).
        """
        # Ensure valid types are loaded (fall back to schema defaults when
        # reload_valid_types() has not been called yet).
        valid_labels = _VALID_NODE_LABELS or _default_node_labels()
        valid_rels = _VALID_REL_TYPES or _default_rel_types()

        # Resolve knowledge_category → knowledge_category_code
        schema = get_schema()
        kc = metadata.get("knowledge_category", "")
        kc_code = metadata.get("knowledge_category_code", "")
        if kc and not kc_code:
            kc_code = schema.knowledge_category_code(kc)

        async with self._driver.session(database=settings.neo4j_database) as session:
            # 1. Upsert the primary Document node with full metadata
            doc_props: dict[str, Any] = {
                "doc_id": doc_id,
                "title": metadata.get("title", ""),
                "normalized_title": metadata.get("normalized_title", ""),
                "doc_code": metadata.get("doc_code") or metadata.get("doc_number", ""),
                "doc_type": metadata.get("doc_type", ""),
                "status": metadata.get("status", ""),
                "publish_date": metadata.get("publish_date", ""),
                "effective_date": metadata.get("effective_date", ""),
                "expiry_date": metadata.get("expiry_date", ""),
                "admin_level": metadata.get("admin_level", ""),
                "subject_summary": metadata.get("subject_summary", ""),
                "keywords": metadata.get("keywords") or metadata.get("subject_words") or [],
                "knowledge_category": kc,
                "knowledge_category_code": kc_code,
                "source": metadata.get("source", ""),
                "is_current": metadata.get("is_current", True),
                "acl_ids": list(dict.fromkeys(metadata.get("acl_ids") or [])),
            }
            # Remove empty-string values to avoid overwriting existing data
            # (empty acl_ids is kept: it is a meaningful "no restriction" value)
            doc_props = {
                k: v
                for k, v in doc_props.items()
                if v != "" and (v != [] or k == "acl_ids")
            }
            doc_props["doc_id"] = doc_id  # Always keep doc_id
            # Validate enum fields (PRD §4.2)
            doc_props = _validate_doc_enum(doc_props)

            await session.run(
                "MERGE (d:Document {doc_id: $doc_id}) SET d += $props",
                doc_id=doc_id,
                props=doc_props,
            )

            # 1.5 Absorb placeholder Document nodes with the same doc_code.
            #     PRD requires: redirect placeholder edges → real doc, then delete.
            #     A direct DETACH DELETE would lose relationships from OTHER
            #     documents pointing to the placeholder (those docs will not be
            #     re-processed).  Three steps: redirect incoming edges →
            #     redirect outgoing edges → delete the emptied placeholder,
            #     so no relationship from another document is lost.
            doc_code = doc_props.get("doc_code", "")
            if doc_code:
                # Step A: redirect all incoming relationships to the real doc
                await session.run(
                    """
                    MATCH (placeholder:Document)
                    WHERE placeholder.doc_code = $doc_code
                      AND placeholder.doc_id STARTS WITH 'ref:'
                      AND placeholder.doc_id <> $doc_id
                      AND coalesce(placeholder.is_placeholder, false) = true
                    WITH placeholder
                    MATCH (src)-[r]->(placeholder)
                    WITH placeholder, src, r, type(r) AS rtype, properties(r) AS rprops
                    MATCH (real:Document {doc_id: $doc_id})
                    CALL apoc.create.relationship(src, rtype, rprops, real) YIELD rel
                    DELETE r
                    """,
                    doc_code=doc_code,
                    doc_id=doc_id,
                )
                # Step B: redirect all outgoing relationships from placeholder
                await session.run(
                    """
                    MATCH (placeholder:Document)
                    WHERE placeholder.doc_code = $doc_code
                      AND placeholder.doc_id STARTS WITH 'ref:'
                      AND placeholder.doc_id <> $doc_id
                      AND coalesce(placeholder.is_placeholder, false) = true
                    WITH placeholder
                    MATCH (placeholder)-[r]->(tgt)
                    WITH placeholder, tgt, r, type(r) AS rtype, properties(r) AS rprops
                    MATCH (real:Document {doc_id: $doc_id})
                    CALL apoc.create.relationship(real, rtype, rprops, tgt) YIELD rel
                    DELETE r
                    """,
                    doc_code=doc_code,
                    doc_id=doc_id,
                )
                # Step C: delete the now-orphaned placeholder node.
                # DETACH DELETE is used defensively: after steps A/B the node
                # should have no relationships left, but a plain DELETE would
                # abort this whole merge if any straggler edge remained.
                await session.run(
                    """
                    MATCH (placeholder:Document)
                    WHERE placeholder.doc_code = $doc_code
                      AND placeholder.doc_id STARTS WITH 'ref:'
                      AND placeholder.doc_id <> $doc_id
                      AND coalesce(placeholder.is_placeholder, false) = true
                    DETACH DELETE placeholder
                    """,
                    doc_code=doc_code,
                    doc_id=doc_id,
                )

            # 2. Batch merge entity nodes, grouped by label and written with
            #    UNWIND to avoid N+1 round-trips.
            node_groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
            for entity in entities:
                label = entity.get("label", "")
                props = entity.get("properties") or {}
                if label not in valid_labels:
                    logger.warning("unknown_node_label", label=label)
                    continue
                node_groups[label].append(props)

            for label, props_list in node_groups.items():
                await _batch_merge_nodes(session, label, props_list)

            # 3. Batch merge relationships, grouped by the
            #    (from_label, to_label, rel_type) triple.
            rel_groups: dict[tuple[str, str, str], list[dict[str, Any]]] = defaultdict(list)
            for rel in relationships:
                from_label = rel.get("from_label", "")
                to_label = rel.get("to_label", "")
                rel_type = rel.get("type", "")
                if (
                    from_label not in valid_labels
                    or to_label not in valid_labels
                    or rel_type not in valid_rels
                ):
                    logger.warning(
                        "invalid_relationship",
                        from_label=from_label,
                        to_label=to_label,
                        rel_type=rel_type,
                    )
                    continue
                rel_groups[(from_label, to_label, rel_type)].append(rel)

            for (from_label, to_label, rel_type), rel_list in rel_groups.items():
                await _batch_merge_rels(session, from_label, to_label, rel_type, rel_list)

        logger.info(
            "document_graph_merged",
            doc_id=doc_id,
            entities=len(entities),
            relationships=len(relationships),
        )

    # ── Entity merging (legacy API, kept for backward compatibility) ──────

    async def merge_entities(
        self,
        doc_id: str,
        entities: list[dict[str, Any]],
        relationships: list[dict[str, Any]],
    ) -> None:
        """Merge extracted entities and relationships into the knowledge graph.

        Legacy per-item API; prefer :meth:`merge_document_graph`.

        Parameters
        ----------
        doc_id:
            The source document identifier.
        entities:
            Each dict must have ``label`` (e.g. Organization) and ``properties``
            containing at least the uniqueness key (e.g. ``name``).
        relationships:
            Each dict must have ``from_label``, ``from_key``, ``to_label``,
            ``to_key``, ``type`` (relationship type) and optionally ``properties``.
        """
        async with self._driver.session(database=settings.neo4j_database) as session:
            # Ensure the Document node exists.
            await session.run(
                "MERGE (d:Document {doc_id: $doc_id})",
                doc_id=doc_id,
            )

            for entity in entities:
                label = entity["label"]
                props = entity["properties"]
                cypher = (
                    f"MERGE (n:{label} {{{_key_clause(label, props)}}})\n"
                    f"SET n += $props"
                )
                await session.run(cypher, props=props)

            for rel in relationships:
                from_label = rel["from_label"]
                from_key = rel["from_key"]
                to_label = rel["to_label"]
                to_key = rel["to_key"]
                rel_type = rel["type"]
                rel_props = rel.get("properties", {})
                # Separate param prefixes so the clauses reference the correct
                # $from_key / $to_key parameters.
                cypher = (
                    f"MATCH (a:{from_label} {{{_match_clause(from_label, from_key, 'from')}}})\n"
                    f"MATCH (b:{to_label} {{{_match_clause(to_label, to_key, 'to')}}})\n"
                    f"MERGE (a)-[r:{rel_type}]->(b)\n"
                    f"SET r += $rel_props"
                )
                await session.run(
                    cypher,
                    from_key=from_key,
                    to_key=to_key,
                    rel_props=rel_props,
                )

        logger.info("neo4j_entities_merged", doc_id=doc_id, count=len(entities))

    # ── Query helpers ────────────────────────────────────────────────────

    async def query_document_graph(
        self,
        doc_id: str,
        *,
        max_depth: int = 2,
        acl_tokens: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return nodes and edges for a document's neighbourhood sub-graph.

        Returns a dict with ``nodes`` (list of dicts) and ``edges`` (list of
        dicts with ``source``, ``target``, ``type``).  When ``acl_tokens``
        is given, Document nodes whose ``acl_ids`` do not overlap the tokens
        are filtered out; placeholder documents are always excluded.

        NOTE: Neo4j does not allow parameters in variable-length path ranges
        ``[*1..$max_depth]``.  The depth is clamped to [1, 3] and embedded as
        a literal.  All node and relationship data is returned as scalar
        values (elementId, labels, properties) so the async driver never
        needs to serialise Node / Relationship objects.
        """
        depth_safe = max(1, min(int(max_depth), 3))
        doc_visibility = "coalesce(node.is_placeholder, false) = false"
        if acl_tokens is not None:
            # ACL filter: an empty acl_ids list means "visible to everyone".
            doc_visibility = (
                "(coalesce(node.is_placeholder, false) = false "
                "AND node.acl_ids IS NOT NULL "
                "AND (size(coalesce(node.acl_ids, [])) = 0 "
                "OR any(token IN coalesce(node.acl_ids, []) WHERE token IN $acl_tokens)))"
            )
        cypher = (
            f"MATCH path = (d:Document {{doc_id: $doc_id}})-[*1..{depth_safe}]-(neighbor) "
            f"WHERE all(node IN nodes(path) WHERE NOT node:Document OR {doc_visibility}) "
            "UNWIND relationships(path) AS rel "
            "WITH DISTINCT rel "
            "WITH "
            "  collect(DISTINCT {eid: elementId(startNode(rel)), lbs: labels(startNode(rel)), props: properties(startNode(rel))}) "
            "  + collect(DISTINCT {eid: elementId(endNode(rel)),   lbs: labels(endNode(rel)),   props: properties(endNode(rel))}) "
            "  AS all_nodes, "
            "  collect(DISTINCT {src: elementId(startNode(rel)), tgt: elementId(endNode(rel)), "
            "                    typ: type(rel), props: properties(rel)}) AS all_rels "
            "RETURN all_nodes AS nodes, all_rels AS rels"
        )
        async with self._driver.session(database=settings.neo4j_database) as session:
            params: dict[str, Any] = {"doc_id": doc_id}
            if acl_tokens is not None:
                params["acl_tokens"] = acl_tokens
            result = await session.run(cypher, **params)
            record = await result.single()

        if record is None:
            return {"nodes": [], "edges": []}

        # Deduplicate nodes by eid (collect DISTINCT on maps isn't true-distinct)
        seen_nodes: set[str] = set()
        nodes: list[dict[str, Any]] = []
        for n in (record["nodes"] or []):
            eid = n["eid"]
            if eid and eid not in seen_nodes:
                seen_nodes.add(eid)
                nodes.append({
                    "id": eid,
                    "labels": list(n["lbs"]),
                    "properties": dict(n["props"] or {}),
                })

        # Deduplicate edges by (source, target, type); parallel relationships
        # of the same type collapse into one edge for visualisation.
        seen_edges: set[tuple] = set()
        edges: list[dict[str, Any]] = []
        for r in (record["rels"] or []):
            key = (r["src"], r["tgt"], r["typ"])
            if key not in seen_edges:
                seen_edges.add(key)
                edges.append({
                    "source": r["src"],
                    "target": r["tgt"],
                    "type": r["typ"],
                    "properties": dict(r["props"] or {}),
                })

        return {"nodes": nodes, "edges": edges}

    async def get_node_count(self) -> int:
        """Return the total number of nodes in the knowledge graph."""
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run("MATCH (n) RETURN count(n) AS cnt")
            record = await result.single()
            return record["cnt"] if record else 0

    async def delete_document_graph(self, doc_id: str) -> dict[str, int]:
        """Delete a Document node, its relationships, and any entity nodes
        that become orphaned as a result.

        Steps:
        1. Collect the element ids of all entities directly connected to
           the Document.
        2. DETACH DELETE the Document node.
        3. Delete each previously-connected entity that no longer has any
           Document referencing it (orphan clean-up).

        Returns
        -------
        dict
            ``{"deleted_nodes": int}`` — total nodes removed
            (the Document plus orphaned entities).
        """
        # NOTE: the two filters on (e) must be a single WHERE — a second
        # consecutive WHERE after one MATCH is a Cypher syntax error.
        cypher = """
        MATCH (d:Document {doc_id: $doc_id})
        OPTIONAL MATCH (d)--(entity)
        WHERE NOT entity:Document
        WITH d, collect(DISTINCT elementId(entity)) AS entity_ids
        DETACH DELETE d
        WITH entity_ids
        UNWIND entity_ids AS eid
        MATCH (e)
        WHERE elementId(e) = eid
          AND NOT EXISTS { MATCH (e)--(other:Document) }
        DETACH DELETE e
        RETURN count(e) AS orphans_deleted
        """
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run(cypher, doc_id=doc_id)
            record = await result.single()
            orphans = record["orphans_deleted"] if record else 0

        total = 1 + orphans  # Document node + orphaned entities
        logger.info(
            "neo4j_document_deleted",
            doc_id=doc_id,
            deleted_nodes=total,
            orphans_deleted=orphans,
        )
        return {"deleted_nodes": total}

    async def delete_all_graph(self) -> dict[str, int]:
        """Delete every node and relationship in the graph database.

        Deletes in batches of 10 000 to bound transaction memory on large
        graphs; loops until a batch comes back short.
        """
        cypher = """
        MATCH (n)
        WITH n LIMIT 10000
        DETACH DELETE n
        RETURN count(n) AS deleted
        """
        total_deleted = 0
        async with self._driver.session(database=settings.neo4j_database) as session:
            while True:
                result = await session.run(cypher)
                record = await result.single()
                batch = record["deleted"] if record else 0
                total_deleted += batch
                if batch < 10000:
                    break

        logger.info("neo4j_all_deleted", deleted_nodes=total_deleted)
        return {"deleted_nodes": total_deleted}

    async def update_node(self, entity_id: str, props: dict[str, Any]) -> dict[str, Any] | None:
        """Update properties on a node identified by elementId.

        Returns the updated node as ``{"id", "labels", "properties"}``,
        or ``None`` when no node with that elementId exists.
        """
        cypher = (
            "MATCH (n) WHERE elementId(n) = $eid "
            "SET n += $props "
            "RETURN elementId(n) AS id, labels(n) AS labels, properties(n) AS props"
        )
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run(cypher, eid=entity_id, props=props)
            record = await result.single()
        if record is None:
            return None
        return {
            "id": record["id"],
            "labels": list(record["labels"]),
            "properties": dict(record["props"]),
        }

    async def delete_node(self, entity_id: str) -> dict[str, int]:
        """DETACH DELETE a node by elementId.

        Returns ``{"deleted_relationships": int}`` — the number of
        relationships removed along with the node (0 if it did not exist).
        """
        cypher = (
            "MATCH (n) WHERE elementId(n) = $eid "
            "WITH n, size([(n)-[r]-() | r]) AS rel_count "
            "DETACH DELETE n "
            "RETURN rel_count"
        )
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run(cypher, eid=entity_id)
            record = await result.single()
        return {"deleted_relationships": record["rel_count"] if record else 0}

    async def delete_relationship_by_id(self, rel_id: str) -> bool:
        """Delete a single relationship by elementId. Returns True if deleted."""
        cypher = (
            "MATCH ()-[r]->() WHERE elementId(r) = $rid "
            "DELETE r RETURN count(r) AS cnt"
        )
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run(cypher, rid=rel_id)
            record = await result.single()
        return bool(record and record["cnt"] > 0)

    async def get_entity(self, entity_id: str) -> dict[str, Any] | None:
        """Retrieve a single entity node by its element id.

        Returns the node with its labels, properties and a ``connections``
        list (relationship type plus neighbour id/labels/name), or ``None``
        when no node matches.
        """
        cypher = (
            "MATCH (n) WHERE elementId(n) = $eid "
            "OPTIONAL MATCH (n)-[r]-(m) "
            "RETURN n, collect(DISTINCT {rel: r, neighbor: m}) AS connections"
        )
        async with self._driver.session(database=settings.neo4j_database) as session:
            result = await session.run(cypher, eid=entity_id)
            record = await result.single()

        if record is None:
            return None

        node = record["n"]
        connections = []
        for conn in record["connections"]:
            # OPTIONAL MATCH with no relationships yields a single {rel: null} entry
            if conn["rel"] is not None:
                connections.append({
                    "type": conn["rel"].type,
                    "neighbor_id": conn["neighbor"].element_id,
                    "neighbor_labels": list(conn["neighbor"].labels),
                    "neighbor_name": dict(conn["neighbor"]).get("name", ""),
                })

        return {
            "id": node.element_id,
            "labels": list(node.labels),
            "properties": dict(node),
            "connections": connections,
        }


# ── Private helpers ──────────────────────────────────────────────────────────


def _key_prop(label: str) -> str:
    """Return the primary key property name for a given node label.

    Uses ``key_property`` from ``graph_schema.yaml`` if available,
    otherwise defaults to ``doc_id`` for Document and ``name`` for others.
    """
    if label == "Document":
        return "doc_id"
    try:
        schema = get_schema()
        # 使用全量映射，确保 phase_3 等非活跃 phase 的类型也能正确解析 key_property
        # Use unfiltered map so that non-active-phase types (e.g. Person with
        # key_property=person_id in phase_3) resolve correctly.
        et_map = schema.all_entity_type_map_unfiltered()
        if label in et_map:
            return et_map[label].get("key_property", "name")
    except Exception:
        # 从 schema 获取 key_property 失败，回退到默认值 "name"
        logger.debug("Failed to determine key property, falling back to 'name'", exc_info=True)
    return "name"


async def _batch_merge_nodes(session: Any, label: str, props_list: list[dict[str, Any]]) -> None:
    """Batch-MERGE nodes sharing one label via a single UNWIND statement.

    Rows whose key property is missing or empty are skipped with a warning.
    If the whole batch fails, falls back to merging rows one at a time so a
    single bad row cannot sink the rest of the document subgraph.
    """
    if not props_list:
        return

    # Validate identifiers before f-string interpolation (injection guard).
    _validate_identifier(label)
    key_prop = _key_prop(label)
    _validate_identifier(key_prop)

    # Build the batch, dropping rows whose key value is None or empty.
    batch: list[dict[str, Any]] = []
    for props in props_list:
        key_val = props.get(key_prop, "")
        if key_val:
            batch.append({"key_val": key_val, "props": props})
            continue
        logger.warning(
            "batch_node_skip_empty_key",
            label=label,
            key_prop=key_prop,
            props_keys=list(props.keys()),
        )

    if not batch:
        return

    statement = (
        f"UNWIND $batch AS item "
        f"MERGE (n:{label} {{{key_prop}: item.key_val}}) "
        f"SET n += item.props"
    )

    try:
        await session.run(statement, batch=batch)
    except Exception as exc:
        # Whole batch failed — retry row by row so the good rows still land,
        # logging one warning for the batch as a whole.
        logger.warning(
            "batch_node_merge_failed_fallback",
            label=label,
            batch_size=len(batch),
            error=str(exc),
        )
        for item in batch:
            await _merge_node(session, label, item["props"])


async def _batch_merge_rels(
    session: Any,
    from_label: str,
    to_label: str,
    rel_type: str,
    rel_list: list[dict[str, Any]],
) -> None:
    """Merge many relationships of one (from, to, type) shape via UNWIND.

    批量 MERGE 同一类型的关系，使用 UNWIND 减少网络往返。
    Items missing either endpoint key are skipped with a warning.  If the
    whole batch statement fails, the original items are retried one-by-one
    through ``_merge_rel`` (which re-checks keys itself).
    """
    if not rel_list:
        return

    # Every identifier below ends up inside the Cypher text — validate each
    # before any string interpolation to block injection.
    for ident in (from_label, to_label, rel_type):
        _validate_identifier(ident)

    from_key_prop = _key_prop(from_label)
    to_key_prop = _key_prop(to_label)
    _validate_identifier(from_key_prop)
    _validate_identifier(to_key_prop)

    # Keep only items with both endpoint keys present; warn about the rest.
    batch: list[dict[str, Any]] = []
    for rel in rel_list:
        from_key = rel.get("from_key", "")
        to_key = rel.get("to_key", "")
        if from_key and to_key:
            batch.append(
                {
                    "from_key": from_key,
                    "to_key": to_key,
                    "props": rel.get("properties") or {},
                }
            )
        else:
            logger.warning(
                "batch_rel_skip_empty_key",
                from_label=from_label,
                to_label=to_label,
                rel_type=rel_type,
                from_key=from_key,
                to_key=to_key,
            )

    if not batch:
        return

    cypher = (
        f"UNWIND $batch AS item "
        f"MATCH (a:{from_label} {{{from_key_prop}: item.from_key}}) "
        f"MATCH (b:{to_label} {{{to_key_prop}: item.to_key}}) "
        f"MERGE (a)-[r:{rel_type}]->(b) "
        f"SET r += item.props"
    )

    try:
        await session.run(cypher, batch=batch)
    except Exception as exc:
        # Whole-batch failure: log once, then fall back to per-item merges.
        # NOTE: iterate rel_list (not batch) because _merge_rel needs the
        # full dict shape (from_label/to_label/type), which batch items lack.
        logger.warning(
            "batch_rel_merge_failed_fallback",
            from_label=from_label,
            to_label=to_label,
            rel_type=rel_type,
            batch_size=len(batch),
            error=str(exc),
        )
        for single_rel in rel_list:
            await _merge_rel(session, single_rel)


async def _merge_node(session: Any, label: str, props: dict[str, Any]) -> None:
    """Merge a single entity node keyed by its label's key property.

    单条节点 MERGE；标签名与 key 属性名在插入 Cypher 前均需通过校验。

    Never raises: validation or driver errors are logged as warnings so a
    caller looping over many nodes (e.g. the batch fallback path) is not
    aborted by a single bad item.
    """
    try:
        # Validate the label before f-string interpolation into Cypher.
        _validate_identifier(label)
        key = _key_prop(label)
        # FIX: the key property name is interpolated into the Cypher text
        # as well — validate it too, mirroring _batch_merge_nodes.
        _validate_identifier(key)
        key_val = props.get(key, "")
        if not key_val:
            # Nothing to merge on — silently skip, consistent with the
            # batch path which drops empty-key items.
            return
        await session.run(
            f"MERGE (n:{label} {{{key}: $key_val}}) SET n += $props",
            key_val=key_val,
            props=props,
        )
    except Exception as exc:
        logger.warning("node_merge_failed", label=label, error=str(exc))


async def _merge_rel(session: Any, rel: dict[str, Any]) -> None:
    """Merge a single relationship between two existing nodes.

    单条关系 MERGE；标签、关系类型与 key 属性名在插入 Cypher 前均需校验。

    Args:
        rel: Mapping with ``from_label``, ``from_key``, ``to_label``,
            ``to_key``, ``type`` and optional ``properties`` keys.

    Never raises: missing dict keys, identifier-validation failures and
    driver errors are all logged as warnings, so the per-item fallback loop
    in ``_batch_merge_rels`` cannot be aborted by one malformed item
    (mirrors the error contract of ``_merge_node``).
    """
    try:
        from_label = rel["from_label"]
        from_key = rel["from_key"]
        to_label = rel["to_label"]
        to_key = rel["to_key"]
        rel_type = rel["type"]
        rel_props = rel.get("properties") or {}

        if not from_key or not to_key:
            # No usable endpoint keys — skip, consistent with the batch path.
            return

        # Validate every identifier that gets interpolated into Cypher.
        _validate_identifier(from_label)
        _validate_identifier(to_label)
        _validate_identifier(rel_type)

        from_prop = _key_prop(from_label)
        to_prop = _key_prop(to_label)
        # FIX: the key property names are interpolated too — validate them,
        # mirroring the checks done in _batch_merge_rels.
        _validate_identifier(from_prop)
        _validate_identifier(to_prop)

        cypher = (
            f"MATCH (a:{from_label} {{{from_prop}: $from_key}})\n"
            f"MATCH (b:{to_label} {{{to_prop}: $to_key}})\n"
            f"MERGE (a)-[r:{rel_type}]->(b)\n"
            f"SET r += $rel_props"
        )
        await session.run(
            cypher,
            from_key=from_key,
            to_key=to_key,
            rel_props=rel_props,
        )
    except Exception as exc:
        # Use .get() here: locals may be unbound if a dict key was missing.
        logger.warning(
            "rel_merge_failed",
            from_label=rel.get("from_label"),
            to_label=rel.get("to_label"),
            rel_type=rel.get("type"),
            error=str(exc),
        )


# ── Legacy helpers ────────────────────────────────────────────────────────────

def _key_clause(label: str, props: dict[str, Any]) -> str:
    """Render the ``{key: $props.key}`` fragment for a legacy MERGE statement.

    The key property name comes from the label's schema-defined uniqueness
    key; the value is referenced out of the ``$props`` parameter map.
    """
    key_name = _key_prop(label)
    return f"{key_name}: $props.{key_name}"


def _match_clause(label: str, key_value: Any, param_prefix: str = "from") -> str:  # noqa: ARG001
    """Render a ``{key: $<prefix>_key}`` fragment for a primary-key look-up.

    The parameter reference follows ``param_prefix`` (``$from_key`` or
    ``$to_key``), fixing an earlier defect where ``$from_key`` was always
    emitted regardless of which endpoint was being matched.
    """
    return f"{_key_prop(label)}: ${param_prefix}_key"
