"""Celery tasks for asynchronous knowledge-graph construction.

知识图谱异步构建的 Celery 任务模块。
提供单文档图谱构建、批量图谱构建、全量重建等任务。
每个任务创建临时的 LLM / Neo4j 客户端，通过 GraphBuilder
完成实体抽取→规范化→Neo4j 写入的完整流程。
"""

from __future__ import annotations

import asyncio
from typing import Any

from app.tasks.celery_app import celery_app
from app.utils.logger import get_logger

# Module-level logger. Calls in this module pass an event-name string plus
# structured keyword fields (e.g. ``doc_id=...``, ``error=...``).
logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Single-document graph-build task
# ---------------------------------------------------------------------------


@celery_app.task(
    bind=True,
    name="app.tasks.graph_task.build_document_graph_task",
    max_retries=2,
    default_retry_delay=60,
    soft_time_limit=300,   # 5-minute soft limit
    time_limit=360,         # 6-minute hard limit
)
def build_document_graph_task(
    self,
    doc_id: str,
    metadata: dict[str, Any],
    content: str,
) -> dict[str, Any]:
    """Build the knowledge graph for a single document.

    Runs the async pipeline ``_run_graph_build`` on a fresh event loop that
    is owned by this call and closed before returning.

    Error handling:

    * ``ValueError`` / ``KeyError`` / ``TypeError`` are treated as permanent
      (bad input data). They are logged and a
      ``{"status": "failed", "error": ..., "retryable": False}`` dict is
      returned without retrying.
    * Any other exception is treated as transient and retried through
      ``self.retry`` with a countdown of ``60 * (retries + 1)`` seconds. This
      includes a pipeline result whose ``status`` is ``"failed"``, which is
      turned into a ``RuntimeError``.

    Parameters
    ----------
    doc_id:
        Unique document identifier.
    metadata:
        Document metadata (title, doc_number, issuing_org, etc.).
    content:
        Full-text content of the document.

    Returns
    -------
    dict
        The pipeline result when it did not fail, otherwise the
        permanent-failure dict described above.
    """
    event_loop = asyncio.new_event_loop()
    try:
        outcome = event_loop.run_until_complete(
            _run_graph_build(doc_id, metadata, content)
        )
        if outcome.get("status") != "failed":
            return outcome
        # A reported failure goes down the retry path below.
        raise RuntimeError(outcome.get("error", "Graph build failed"))
    except (ValueError, KeyError, TypeError) as err:
        # Data-shape problems: retrying would not change the result.
        logger.error(
            "graph_task_permanent_failure",
            doc_id=doc_id,
            attempt=self.request.retries,
            error=str(err),
        )
        return {"status": "failed", "error": str(err), "retryable": False}
    except Exception as err:
        # Presumed transient (timeouts, dropped connections): linear backoff.
        attempt = self.request.retries
        logger.error(
            "graph_task_failed",
            doc_id=doc_id,
            attempt=attempt,
            error=str(err),
        )
        raise self.retry(exc=err, countdown=60 * (attempt + 1))
    finally:
        event_loop.close()


async def _run_graph_build(
    doc_id: str,
    metadata: dict[str, Any],
    content: str,
) -> dict[str, Any]:
    """Run the graph-build pipeline for one document with fresh clients.

    A new ``LLMClient`` and ``Neo4jClient`` are created for this call and are
    always closed afterwards. Each client is registered for cleanup as soon
    as it exists. A failure while constructing the second client, or while
    closing either one, therefore cannot leak the other.

    The scene type comes from ``metadata["document_scene_type"]``, falling
    back to ``metadata["knowledge_category_code"]`` and then ``""``.

    Returns
    -------
    dict
        Whatever ``GraphBuilder.build_graph`` returns.
    """
    from contextlib import AsyncExitStack

    from app.core.graph_builder import GraphBuilder
    from app.infrastructure.llm_client import LLMClient
    from app.infrastructure.neo4j_client import Neo4jClient

    async with AsyncExitStack() as cleanup:
        llm_client = LLMClient()
        cleanup.push_async_callback(llm_client.close)
        neo4j_client = Neo4jClient.from_settings()
        cleanup.push_async_callback(neo4j_client.close)

        builder = GraphBuilder(llm_client, neo4j_client)
        scene_type = (
            metadata.get("document_scene_type")
            or metadata.get("knowledge_category_code")
            or ""
        )
        return await builder.build_graph(doc_id, metadata, content, scene_type=scene_type)


# ---------------------------------------------------------------------------
# Bulk graph-build task (processes a batch of doc_ids)
# ---------------------------------------------------------------------------


@celery_app.task(
    name="app.tasks.graph_task.bulk_build_graph_task",
    soft_time_limit=3600,
    time_limit=3660,
)
def bulk_build_graph_task(doc_ids: list[str]) -> dict[str, Any]:
    """Enqueue one ``build_document_graph_task`` per document in *doc_ids*.

    This is a synchronous Celery entry point. It drives ``_dispatch_bulk``
    (which reads each document's metadata and text from Elasticsearch and
    enqueues the per-document task) on a private event loop, which is closed
    before returning.

    Parameters
    ----------
    doc_ids:
        List of document IDs to process.

    Returns
    -------
    dict
        ``{"dispatched": int, "failed_fetch": list[str]}``
    """
    runner = asyncio.new_event_loop()
    try:
        summary = runner.run_until_complete(_dispatch_bulk(doc_ids))
    finally:
        runner.close()
    return summary


async def _dispatch_bulk(doc_ids: list[str]) -> dict[str, Any]:
    """Fetch metadata + content from ES and enqueue one graph task per document.

    Documents are handled sequentially. A document is reported in
    ``failed_fetch`` in any of these cases:

    * its metadata cannot be read;
    * its reconstructed text is empty;
    * enqueueing its ``build_document_graph_task`` raises.

    Processing continues with the next document in every case.

    Parameters
    ----------
    doc_ids:
        Document IDs to dispatch.

    Returns
    -------
    dict
        ``{"dispatched": int, "failed_fetch": list[str]}``
    """
    from opensearchpy import AsyncOpenSearch

    from app.config import settings

    es_kwargs: dict[str, Any] = {
        "hosts": [settings.es_host],
        "timeout": 60,
    }
    if settings.es_username:
        es_kwargs["http_auth"] = (settings.es_username, settings.es_password)

    es = AsyncOpenSearch(**es_kwargs)
    failed: list[str] = []
    dispatched = 0

    try:
        for doc_id in doc_ids:
            try:
                meta_resp = await es.get(index=settings.es_meta_index, id=doc_id)
                # The client may return a plain dict or a response object that
                # exposes the payload as ``.body``; accept both.
                raw = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
                meta: dict[str, Any] = raw["_source"]
                content = await _reconstruct_content(es, doc_id, settings)
            except Exception as exc:
                logger.error("bulk_graph_fetch_error", doc_id=doc_id, error=str(exc))
                failed.append(doc_id)
                continue

            if not content:
                logger.warning("bulk_graph_empty_content", doc_id=doc_id)
                failed.append(doc_id)
                continue

            try:
                build_document_graph_task.delay(doc_id, meta, content)
            except Exception as exc:
                # Enqueue/broker problems are logged separately so they are not
                # mistaken for Elasticsearch fetch errors.
                logger.error("bulk_graph_dispatch_error", doc_id=doc_id, error=str(exc))
                failed.append(doc_id)
                continue

            dispatched += 1
            logger.info("bulk_graph_dispatched", doc_id=doc_id)
    finally:
        await es.close()

    return {"dispatched": dispatched, "failed_fetch": failed}


# ---------------------------------------------------------------------------
# Rebuild tasks (admin graph management)
# ---------------------------------------------------------------------------


@celery_app.task(
    bind=True,
    name="app.tasks.graph_task.rebuild_document_graphs_task",
    soft_time_limit=3600,
    time_limit=3660,
    track_started=True,
)
def rebuild_document_graphs_task(self, doc_ids: list[str]) -> dict[str, Any]:
    """Rebuild the knowledge graphs of the documents in *doc_ids*.

    This is a synchronous Celery entry point. It runs ``_run_rebuild`` on a
    private event loop, which is closed before returning, and passes the
    bound task along.

    The per-document work (old-graph deletion, ES fetches, rebuild) and
    progress reporting via ``update_state`` all happen inside
    ``_run_rebuild``.
    """
    runner = asyncio.new_event_loop()
    try:
        summary = runner.run_until_complete(_run_rebuild(self, doc_ids))
    finally:
        runner.close()
    return summary


async def _run_rebuild(self_task: Any, doc_ids: list[str]) -> dict[str, Any]:
    """Rebuild the knowledge graph of each document in *doc_ids*, one by one.

    For every document, the ES metadata and the reconstructed full text are
    fetched first. The existing Neo4j graph is deleted only once both are
    available, immediately before ``GraphBuilder.build_graph`` recreates it.
    A document whose source data is missing or empty therefore keeps its
    current graph instead of being left with none.

    Before each document, progress is published via
    ``self_task.update_state(state="PROCESSING", meta={...})``.

    All clients (ES, LLM, Neo4j) are registered for cleanup as soon as they
    are created. They are closed even if a later constructor or another
    client's ``close()`` fails.

    Parameters
    ----------
    self_task:
        The bound Celery task; only ``update_state`` is used.
    doc_ids:
        IDs of the documents to rebuild.

    Returns
    -------
    dict
        ``{"total": int, "completed": int, "failed": list[str],
        "status": "completed"}``. Per-document errors are logged and the
        document is listed in ``failed``; they do not abort the batch.
    """
    from contextlib import AsyncExitStack

    from opensearchpy import AsyncOpenSearch

    from app.config import settings
    from app.core.graph_builder import GraphBuilder
    from app.infrastructure.llm_client import LLMClient
    from app.infrastructure.neo4j_client import Neo4jClient

    es_kwargs: dict[str, Any] = {
        "hosts": [settings.es_host],
        "timeout": 60,
    }
    if settings.es_username:
        es_kwargs["http_auth"] = (settings.es_username, settings.es_password)

    total = len(doc_ids)
    completed = 0
    failed_ids: list[str] = []

    async with AsyncExitStack() as cleanup:
        es = AsyncOpenSearch(**es_kwargs)
        cleanup.push_async_callback(es.close)
        llm_client = LLMClient()
        cleanup.push_async_callback(llm_client.close)
        neo4j_client = Neo4jClient.from_settings()
        cleanup.push_async_callback(neo4j_client.close)
        builder = GraphBuilder(llm_client, neo4j_client)

        for i, doc_id in enumerate(doc_ids):
            self_task.update_state(
                state="PROCESSING",
                meta={
                    "progress": round(i / total, 2),
                    "current": i,
                    "total": total,
                    "current_doc_id": doc_id,
                },
            )
            try:
                # Gather everything the rebuild needs BEFORE touching Neo4j.
                meta_resp = await es.get(index=settings.es_meta_index, id=doc_id)
                raw = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
                meta = raw["_source"]

                content = await _reconstruct_content(es, doc_id, settings)
                if not content:
                    # Existing graph is intentionally left untouched.
                    logger.warning("rebuild_empty_content", doc_id=doc_id)
                    failed_ids.append(doc_id)
                    continue

                scene_type = (
                    meta.get("document_scene_type")
                    or meta.get("knowledge_category_code")
                    or ""
                )

                # Only now drop the old graph; build_graph recreates it.
                await neo4j_client.delete_document_graph(doc_id)
                result = await builder.build_graph(doc_id, meta, content, scene_type=scene_type)
                if result.get("status") == "completed":
                    completed += 1
                else:
                    logger.warning(
                        "rebuild_doc_not_completed",
                        doc_id=doc_id,
                        status=result.get("status"),
                        error=result.get("error"),
                    )
                    failed_ids.append(doc_id)

            except Exception as exc:
                # NOTE(review): Celery's SoftTimeLimitExceeded subclasses
                # Exception, so a soft limit firing mid-document is probably
                # recorded here as a per-document failure while the loop keeps
                # going until the hard limit — confirm whether it should be
                # re-raised instead.
                logger.error("rebuild_doc_failed", doc_id=doc_id, error=str(exc))
                failed_ids.append(doc_id)

    return {
        "total": total,
        "completed": completed,
        "failed": failed_ids,
        "status": "completed",
    }


@celery_app.task(
    name="app.tasks.graph_task.rebuild_all_graph_task",
    soft_time_limit=7200,
    time_limit=7260,
)
def rebuild_all_graph_task() -> dict[str, Any]:
    """Queue a graph rebuild covering every completed document in ES.

    This is a synchronous Celery entry point. It runs ``_run_rebuild_all``
    on a private event loop, which is closed before returning, and returns
    its result unchanged.
    """
    runner = asyncio.new_event_loop()
    try:
        summary = runner.run_until_complete(_run_rebuild_all())
    finally:
        runner.close()
    return summary


async def _run_rebuild_all() -> dict[str, Any]:
    """Collect every ``status == "completed"`` document ID and queue one rebuild.

    A single size-limited search cannot return more than the index's
    ``max_result_window`` (10,000 by default), so larger collections would
    be silently truncated. IDs are therefore read page by page with the
    scroll API until an empty page is returned. The server-side scroll
    context is released afterwards on a best-effort basis.

    Returns
    -------
    dict
        ``{"total": 0, "dispatched": 0}`` when there is nothing to rebuild.
        Otherwise ``{"total": int, "dispatched_task_id": str}``, where the ID
        is that of the queued ``rebuild_document_graphs_task``.
    """
    from opensearchpy import AsyncOpenSearch

    from app.config import settings

    es_kwargs: dict[str, Any] = {
        "hosts": [settings.es_host],
        "timeout": 60,
    }
    if settings.es_username:
        es_kwargs["http_auth"] = (settings.es_username, settings.es_password)

    # Hits per scroll page, and how long the cluster keeps the scroll
    # context alive between consecutive page requests.
    page_size = 1000
    scroll_keepalive = "2m"

    doc_ids: list[str] = []
    scroll_id: str | None = None
    es = AsyncOpenSearch(**es_kwargs)
    try:
        resp = await es.search(
            index=settings.es_meta_index,
            body={
                "size": page_size,
                "query": {"term": {"status": "completed"}},
                "_source": False,
            },
            scroll=scroll_keepalive,
        )
        while True:
            raw = resp if isinstance(resp, dict) else resp.body
            # A page may carry a new scroll id; always continue from the latest.
            scroll_id = raw.get("_scroll_id") or scroll_id
            hits = raw["hits"]["hits"]
            doc_ids.extend(hit["_id"] for hit in hits)
            # An empty page means the scroll is exhausted; without a scroll
            # id there is no way to request the next page.
            if not hits or not scroll_id:
                break
            resp = await es.scroll(
                body={"scroll_id": scroll_id, "scroll": scroll_keepalive}
            )
    finally:
        if scroll_id:
            # Free the server-side context now rather than waiting for the
            # keep-alive to expire; failure here must not mask the result.
            try:
                await es.clear_scroll(body={"scroll_id": scroll_id})
            except Exception as exc:
                logger.warning("rebuild_all_clear_scroll_failed", error=str(exc))
        await es.close()

    logger.info("rebuild_all_collected", total=len(doc_ids))

    if not doc_ids:
        return {"total": 0, "dispatched": 0}

    # NOTE(review): all IDs go to a single rebuild task whose soft time limit
    # is 3600s; very large collections may need to be dispatched in batches.
    task = rebuild_document_graphs_task.delay(doc_ids)
    return {"total": len(doc_ids), "dispatched_task_id": task.id}


# ---------------------------------------------------------------------------
# Content reconstruction helper
# ---------------------------------------------------------------------------


async def _reconstruct_content(
    es: Any, doc_id: str, task_settings: Any, max_chunks: int = 2000
) -> str:
    """Reconstruct a document's full text from the ES chunk index.

    Chunks are linked to documents by ``content_hash`` rather than
    ``doc_id``. The function therefore:

    1. reads the hash from the document's entry in the meta index;
    2. fetches the matching chunks sorted by ``chunk_index`` ascending;
    3. joins their ``content`` with newlines.

    The parameter is named ``task_settings`` so it does not shadow a
    module-level ``settings`` import.

    Parameters
    ----------
    es:
        Async OpenSearch/Elasticsearch client; ``get`` and ``search`` are
        awaited. Responses may be plain dicts or objects exposing ``.body``.
    doc_id:
        ID of the document in the meta index.
    task_settings:
        Object providing ``es_meta_index`` and ``es_chunk_index``.
    max_chunks:
        Maximum number of chunks requested in the single search (default
        2000). If the index reports more matching chunks than were returned,
        a warning is logged, because the text is then incomplete.

    Returns
    -------
    str
        The reconstructed text, or ``""`` when no ``content_hash`` could be
        obtained for the document. Chunks whose ``content`` is missing or
        null contribute an empty string. Errors from the chunk search
        propagate to the caller.
    """
    try:
        meta_resp = await es.get(
            index=task_settings.es_meta_index,
            id=doc_id,
            _source=["content_hash"],
        )
        raw = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
        content_hash = raw.get("_source", {}).get("content_hash", "")
    except Exception as exc:
        # Best effort: treat a failed lookup as "no content", but keep the
        # cause visible instead of silently discarding it.
        logger.warning("reconstruct_meta_fetch_failed", doc_id=doc_id, error=str(exc))
        content_hash = ""

    if not content_hash:
        logger.warning("reconstruct_no_content_hash", doc_id=doc_id)
        return ""

    resp = await es.search(
        index=task_settings.es_chunk_index,
        body={
            "query": {"term": {"content_hash": content_hash}},
            "sort": [{"chunk_index": {"order": "asc"}}],
            "size": max_chunks,
            "_source": ["content"],
        },
    )
    raw_resp = resp if isinstance(resp, dict) else resp.body
    hits = raw_resp["hits"]["hits"]

    # ``hits.total`` is normally {"value": n, ...}; some configurations return
    # a plain int. Warn when the document has more chunks than were fetched.
    total = raw_resp["hits"].get("total")
    total_count = total.get("value") if isinstance(total, dict) else total
    if isinstance(total_count, int) and total_count > len(hits):
        logger.warning(
            "reconstruct_chunks_truncated",
            doc_id=doc_id,
            total_chunks=total_count,
            returned=len(hits),
        )

    # ``or ""`` also covers chunks stored with an explicit null content.
    return "\n".join((hit["_source"].get("content") or "") for hit in hits)
