"""Celery tasks for research execution.

Each task runs the async engine pipeline inside a new event loop,
creating its own short-lived clients — following the same pattern
as ingest_task.py (lines 28-37, 78-119).

研究执行 Celery 任务模块。
- execute_research_run_task: 执行完整研究
- execute_section_rerun_task: 执行章节重跑
"""

from __future__ import annotations

import asyncio
import json
from datetime import datetime, timezone
from typing import Any

from celery.utils.log import get_task_logger

from app.tasks.celery_app import celery_app

logger = get_task_logger(__name__)

# Stream key prefix and TTL
_STREAM_PREFIX = "research:run:"
_STREAM_TTL = 3600  # 1 hour
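
# Consumers (e.g. an SSE endpoint) are expected to tail the stream at
# "research:run:{run_id}:events". A read-side sketch, assuming the
# RedisClient wrapper exposes the standard XREAD signature:
#
#     entries = await redis_client.xread(
#         {f"{_STREAM_PREFIX}{run_id}:events": last_seen_id}, block=5000,
#     )
#     for _stream, messages in entries:
#         for _msg_id, fields in messages:
#             chunk = json.loads(fields["data"])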

# Source group whose references dedup per profile rather than per document;
# the key format matches the frontend's getReferenceKey (research.ts:71).
_GUIDE_GROUP = "guide"


def _run_async(coro):
    """Run a coroutine to completion on a fresh event loop.

    Celery workers are synchronous, so each task gets its own loop rather
    than reusing one that an earlier task on the worker may have closed.
    """
    loop = asyncio.new_event_loop()
    # Register the loop so code calling asyncio.get_event_loop() finds it.
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        asyncio.set_event_loop(None)
        loop.close()


def _utcnow() -> datetime:
    return datetime.now(timezone.utc)


def _ref_dedup_key(chunk_payload: dict) -> str:
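    """Build the dedup key for a streamed reference.

    Guide references collapse per profile (falling back to doc_id);
    everything else collapses per document:

    >>> _ref_dedup_key({"source_group": "guide", "profile_id": "p1", "doc_id": "d1"})
    'guide:p1'
    >>> _ref_dedup_key({"source_group": "regulation", "doc_id": "d1"})
    'doc:d1'
    """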
    sg = chunk_payload.get("source_group", "")
    pid = chunk_payload.get("profile_id")
    doc_id = chunk_payload.get("doc_id", "")
    if sg == _GUIDE_GROUP:
        return f"guide:{pid or doc_id}"
    return f"doc:{doc_id}"


# ---------------------------------------------------------------------------
# Full research run
# ---------------------------------------------------------------------------

@celery_app.task(
    name="tasks.execute_research_run",
    bind=True,
    max_retries=0,
    acks_late=True,
    track_started=True,
)
def execute_research_run_task(
    self,
    run_id: str,
    record_id: str,
    user_id: str,
    acl_tokens: list[str],
) -> dict[str, Any]:
    logger.info("research_run_start run_id=%s record_id=%s", run_id, record_id)
    return _run_async(_execute_full_run(run_id, record_id, user_id, acl_tokens))
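
# A dispatch sketch from the API layer (call-site names such as "run" and
# "current_user" are illustrative; the real route code may differ):
#
#     execute_research_run_task.delay(
#         run_id=run.id,
#         record_id=record.id,
#         user_id=current_user.id,
#         acl_tokens=current_user.acl_tokens,
#     )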


async def _execute_full_run(
    run_id: str,
    record_id: str,
    user_id: str,
    acl_tokens: list[str],
) -> dict[str, Any]:
    """Worker-internal: execute full research and persist results."""

    from app.config import settings
    from app.core.embedding import EmbeddingService
    from app.core.graph_query_service import GraphQueryService
    from app.core.permission import PermissionContext
    from app.core.research_engine import ResearchEngine
    from app.api.schemas.research import ResearchPlan, ResearchTask
    from app.infrastructure.embedding_client import EmbeddingClient
    from app.infrastructure.es_client import ESClient
    from app.infrastructure.llm_client import LLMClient
    from app.infrastructure.mysql_client import MySQLClient
    from app.infrastructure.neo4j_client import Neo4jClient
    from app.infrastructure.redis_client import RedisClient
    from app.infrastructure.research_record_store import (
        ResearchRecordRunStore,
        ResearchRecordStore,
    )
    from app.infrastructure.session_store import build_research_session_store

    # -- Build short-lived clients (C13) --
    redis_client = RedisClient.from_settings()
    mysql_client = await MySQLClient.from_settings() if settings.mysql_enabled else None

    es_client = ESClient.from_settings(redis_client=redis_client)

    neo4j_client = Neo4jClient.from_settings()
    embedding_client = EmbeddingClient()
    llm_client = LLMClient()

    # Build stores
    record_store = ResearchRecordStore(mysql_client)
    run_store = ResearchRecordRunStore(mysql_client)

    # Build engine
    embedding_svc = EmbeddingService(embedding_client)
    graph_svc = GraphQueryService(neo4j_client) if neo4j_client else None
    session_store = build_research_session_store(
        redis_client=redis_client, mysql_client=mysql_client,
    )
    engine = ResearchEngine(
        es_client=es_client,
        embedding_service=embedding_svc,
        graph_service=graph_svc,
        llm_client=llm_client,
        session_store=session_store,
    )

    # Rebuild the permission context from the serialized task args (C11)
    perm = PermissionContext(user_id=user_id, acl_tokens=acl_tokens)

    stream_key = f"{_STREAM_PREFIX}{run_id}:events"

    try:
        # Update states (C16)
        await run_store.update_run(run_id, status="running", started_at=_utcnow())
        await record_store.update_record(record_id, user_id, status="running")

        # Load record
        record = await record_store.get_record(record_id, user_id)
        if not record:
            raise ValueError(f"Record {record_id} not found")

        task = ResearchTask(**(record["task_json"] or {}))
        plan = ResearchPlan(**(record.get("plan_json") or {}))

        # Extract seed docs (C2)
        from app.core.research_record_service import ResearchRecordService
        seed_doc_ids = ResearchRecordService.extract_seed_doc_ids(record) or None

        # Accumulators
        report_acc: dict[str, Any] = {
            "executive_summary": "",
            "findings": [],
            "conflicts": [],
            "open_questions": [],
            "sections": [],
            "one_page_summary": "",
            "final_document_md": "",
            "recommended_next_steps": [],
            "citation_map": [],
        }
        references_acc: dict[str, dict] = {}  # dedup key → ref dict
        had_error = False
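
        # Each streamed chunk carries a "type" plus type-specific fields; a
        # "finding" chunk, for example, serializes roughly as (shape inferred
        # from the handling below):
        #     {"type": "finding", "title": "...", "content": "...",
        #      "strength": "high", "payload": {"source_doc_ids": ["doc-1"]}}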

        async for chunk in engine.run_deep_research(
            task, plan, perm, seed_doc_ids=seed_doc_ids,
        ):
            # Write to Redis Stream
            payload = json.dumps(chunk.model_dump(exclude_none=True), ensure_ascii=False)
            await redis_client.xadd(stream_key, {"data": payload})

            # Aggregate report fields
            ct = chunk.type
            p = chunk.payload or {}

            if ct == "summary":
                report_acc["executive_summary"] = chunk.content or ""
            elif ct == "finding":
                report_acc["findings"].append({
                    "title": chunk.title or "",
                    "content": chunk.content or "",
                    "strength": chunk.strength or "medium",
                    "source_doc_ids": p.get("source_doc_ids", []),
                })
            elif ct == "conflict":
                report_acc["conflicts"].append({
                    "title": chunk.title or "",
                    "content": chunk.content or "",
                    "severity": chunk.severity or "medium",
                    "source_doc_ids": p.get("source_doc_ids", []),
                })
            elif ct == "open_question":
                report_acc["open_questions"].append({
                    "question": chunk.title or "",
                    "reason": chunk.content or "",
                })
            elif ct == "section":
                report_acc["sections"].append({
                    "title": chunk.title or "",
                    "summary": p.get("summary", ""),
                    "content": chunk.content or "",
                    "source_doc_ids": p.get("source_doc_ids", []),
                })
            elif ct == "report":
                report_acc["one_page_summary"] = chunk.content or ""
            elif ct == "follow_up":
                if p.get("kind") != "section_rerun_notes":
                    items = p.get("items", [])
                    if not items and chunk.content:
                        items = [s.strip() for s in chunk.content.split("\n") if s.strip()]
                    report_acc["recommended_next_steps"] = items
            elif ct == "final_document":
                report_acc["final_document_md"] = chunk.content or ""
            elif ct == "reference":
                ref_p = {**p, "doc_id": chunk.doc_id}
                key = _ref_dedup_key(ref_p)
                references_acc[key] = {
                    "doc_id": chunk.doc_id,
                    "title": chunk.title,
                    "doc_number": chunk.doc_number,
                    "relevance_score": chunk.relevance_score,
                    "source_group": p.get("source_group"),
                    "source_label": p.get("source_label"),
                    "summary": p.get("summary"),
                    "profile_id": p.get("profile_id"),
                    "matched_fields": p.get("matched_fields"),
                    "matched_field_labels": p.get("matched_field_labels"),
                    "matched_chunks": p.get("matched_chunks") or [],
                }
            elif ct == "error":
                had_error = True
                await record_store.update_record(
                    record_id, user_id,
                    status="failed", last_error=chunk.content,
                )
                await run_store.update_run(
                    run_id, status="failed", error=chunk.content,
                    completed_at=_utcnow(),
                )
            elif ct == "done":
                # Capture citation_map from engine's done payload
                report_acc["citation_map"] = p.get("citation_map") or []
                # C1: only persist success if no error
                if not had_error:
                    await record_store.save_report(record_id, report_acc)
                    await record_store.update_record(
                        record_id, user_id,
                        status="completed",
                        summary=(report_acc["executive_summary"] or "")[:500],
                        references_json=list(references_acc.values()),
                        completed_at=_utcnow(),
                        last_error=None,  # H7: clear previous error on success
                    )
                    await run_store.update_run(
                        run_id, status="completed", completed_at=_utcnow(),
                    )

        # Set stream TTL
        await redis_client.expire(stream_key, _STREAM_TTL)

    except Exception:
        logger.exception("research_run_error run_id=%s", run_id)
        await record_store.update_record(
            record_id, user_id, status="failed", last_error="内部错误",
        )
        await run_store.update_run(
            run_id, status="failed", error="内部错误", completed_at=_utcnow(),
        )
    finally:
        # Clean up clients
        if mysql_client:
            await mysql_client.close()
        await redis_client.close()
        await es_client.raw.close()
        if neo4j_client:
            await neo4j_client.close()

    return {"run_id": run_id, "status": "done"}


# ---------------------------------------------------------------------------
# Section rerun
# ---------------------------------------------------------------------------

@celery_app.task(
    name="tasks.execute_section_rerun",
    bind=True,
    max_retries=0,
    acks_late=True,
    track_started=True,
)
def execute_section_rerun_task(
    self,
    run_id: str,
    record_id: str,
    user_id: str,
    acl_tokens: list[str],
    section_title: str,
    section_summary: str | None = None,
    source_doc_ids: list[str] | None = None,
) -> dict[str, Any]:
    logger.info("section_rerun_start run_id=%s section=%s", run_id, section_title)
    return _run_async(_execute_section_rerun(
        run_id, record_id, user_id, acl_tokens,
        section_title, section_summary, source_doc_ids,
    ))


async def _execute_section_rerun(
    run_id: str,
    record_id: str,
    user_id: str,
    acl_tokens: list[str],
    section_title: str,
    section_summary: str | None,
    source_doc_ids: list[str] | None,
) -> dict[str, Any]:
    """Worker-internal: execute section rerun and merge into existing report."""

    from app.config import settings
    from app.core.embedding import EmbeddingService
    from app.core.graph_query_service import GraphQueryService
    from app.core.permission import PermissionContext
    from app.core.research_engine import ResearchEngine
    from app.api.schemas.research import ResearchTask, ResearchPlan
    from app.infrastructure.embedding_client import EmbeddingClient
    from app.infrastructure.es_client import ESClient
    from app.infrastructure.llm_client import LLMClient
    from app.infrastructure.mysql_client import MySQLClient
    from app.infrastructure.neo4j_client import Neo4jClient
    from app.infrastructure.redis_client import RedisClient
    from app.infrastructure.research_record_store import (
        ResearchRecordRunStore,
        ResearchRecordStore,
    )
    from app.infrastructure.session_store import build_research_session_store
    from app.core.research_record_service import ResearchRecordService

    redis_client = RedisClient.from_settings()
    mysql_client = await MySQLClient.from_settings() if settings.mysql_enabled else None

    es_client = ESClient.from_settings(redis_client=redis_client)

    neo4j_client = Neo4jClient.from_settings()
    embedding_client = EmbeddingClient()
    llm_client = LLMClient()

    record_store = ResearchRecordStore(mysql_client)
    run_store = ResearchRecordRunStore(mysql_client)

    embedding_svc = EmbeddingService(embedding_client)
    graph_svc = GraphQueryService(neo4j_client) if neo4j_client else None
    session_store = build_research_session_store(
        redis_client=redis_client, mysql_client=mysql_client,
    )
    engine = ResearchEngine(
        es_client=es_client,
        embedding_service=embedding_svc,
        graph_service=graph_svc,
        llm_client=llm_client,
        session_store=session_store,
    )

    perm = PermissionContext(user_id=user_id, acl_tokens=acl_tokens)
    stream_key = f"{_STREAM_PREFIX}{run_id}:events"

    try:
        # C16 + C19: update states
        await run_store.update_run(run_id, status="running", started_at=_utcnow())
        await record_store.update_record(record_id, user_id, status="running")

        record = await record_store.get_record(record_id, user_id)
        if not record:
            raise ValueError(f"Record {record_id} not found")

        task = ResearchTask(**(record["task_json"] or {}))
        plan = ResearchPlan(**(record.get("plan_json") or {}))
        seed_doc_ids = ResearchRecordService.extract_seed_doc_ids(record) or None

        # Load existing report for merging
        existing_report_row = await record_store.get_report(record_id, user_id)
        existing_report = (
            existing_report_row.get("report_json", {})
            if existing_report_row else {}
        )

        new_section = None
        had_error = False
        # C20: collect references for merge
        new_references: dict[str, dict] = {}

        async for chunk in engine.rerun_section(
            task, plan, section_title, perm,
            section_summary=section_summary,
            source_doc_ids=source_doc_ids,
            seed_doc_ids=seed_doc_ids,
        ):
            payload = json.dumps(chunk.model_dump(exclude_none=True), ensure_ascii=False)
            await redis_client.xadd(stream_key, {"data": payload})

            ct = chunk.type
            p = chunk.payload or {}

            if ct == "section":
                new_section = {
                    "title": chunk.title or "",
                    "summary": p.get("summary", ""),
                    "content": chunk.content or "",
                    "source_doc_ids": p.get("source_doc_ids", []),
                }
            elif ct == "reference":
                ref_p = {**p, "doc_id": chunk.doc_id}
                key = _ref_dedup_key(ref_p)
                new_references[key] = {
                    "doc_id": chunk.doc_id,
                    "title": chunk.title,
                    "doc_number": chunk.doc_number,
                    "relevance_score": chunk.relevance_score,
                    "source_group": p.get("source_group"),
                    "source_label": p.get("source_label"),
                    "summary": p.get("summary"),
                    "profile_id": p.get("profile_id"),
                    "matched_fields": p.get("matched_fields"),
                    "matched_field_labels": p.get("matched_field_labels"),
                    "matched_chunks": p.get("matched_chunks") or [],
                }
            elif ct == "error":
                had_error = True
                # C19: restore to completed on failure, don't downgrade
                await record_store.update_record(
                    record_id, user_id,
                    status="completed",
                    last_error=chunk.content,
                )
                await run_store.update_run(
                    run_id, status="failed", error=chunk.content,
                    completed_at=_utcnow(),
                )
            elif ct == "done":
                if not had_error and new_section and existing_report:
                    # Replace matching section
                    old_sections = existing_report.get("sections", [])
                    updated = False
                    new_sections = []
                    for s in old_sections:
                        if s.get("title") == new_section["title"]:
                            new_sections.append(new_section)
                            updated = True
                        else:
                            new_sections.append(s)
                    if not updated:
                        new_sections.append(new_section)

                    updated_report = {**existing_report, "sections": new_sections}
                    await record_store.save_report(record_id, updated_report)

                    # C20: merge references
                    old_refs = record.get("references_json") or []
                    # Stored refs already carry doc_id / profile_id /
                    # source_group flat, so they key the same way as
                    # streamed reference chunks.
                    merged_refs: dict[str, dict] = {
                        _ref_dedup_key(ref): ref for ref in old_refs
                    }
                    merged_refs.update(new_references)

                    # C19: restore to completed
                    await record_store.update_record(
                        record_id, user_id,
                        status="completed",
                        references_json=list(merged_refs.values()),
                        last_error=None,  # H7: clear previous error
                    )
                    await record_store.touch_record(record_id, user_id)
                    await run_store.update_run(
                        run_id, status="completed", completed_at=_utcnow(),
                    )
                elif not had_error:
                    # C19: restore even if no new section produced
                    await record_store.update_record(
                        record_id, user_id, status="completed", last_error=None,
                    )
                    await run_store.update_run(
                        run_id, status="completed", completed_at=_utcnow(),
                    )

        await redis_client.expire(stream_key, _STREAM_TTL)

    except Exception:
        logger.exception("section_rerun_error run_id=%s", run_id)
        # C19: restore to completed, don't leave in running
        await record_store.update_record(
            record_id, user_id, status="completed", last_error="章节重跑内部错误",
        )
        await run_store.update_run(
            run_id, status="failed", error="内部错误", completed_at=_utcnow(),
        )
    finally:
        if mysql_client:
            await mysql_client.close()
        await redis_client.close()
        await es_client.raw.close()
        if neo4j_client:
            await neo4j_client.close()

    return {"run_id": run_id, "status": "done"}
