"""Celery tasks for document ingestion and permission updates.

Each task runs its async pipeline on a fresh, short-lived event loop that the
task creates itself (Celery worker processes provide no running loop). The
pipeline creates its own short-lived async clients, so it is safe to run in a
separate worker process.

文档入库与权限更新的 Celery 异步任务模块。
主要包含两个任务：
1. ingest_document_task —— 执行完整的文档入库流水线（解析→分块→向量化→写入 ES/Neo4j）
2. update_permissions_task —— 更新文档 ACL 并重算所有共享 chunk 的权限
每个任务在独立事件循环中运行异步管线，支持失败自动重试。
"""

from __future__ import annotations

import asyncio
from typing import Any

# 显式导入 MaxRetriesExceededError，避免通过 self.MaxRetriesExceededError 访问
from celery.exceptions import MaxRetriesExceededError, Retry
from celery.utils.log import get_task_logger

from app.tasks.celery_app import celery_app

# Module-level Celery task logger. It behaves like a standard ``logging.Logger``:
# pass values as lazy %-style positional arguments (structlog-style keyword
# fields such as ``error=...`` are not accepted by the stdlib logging API).
logger = get_task_logger(__name__)


def _run_async(coro):
    """Run an async coroutine in a new event loop (Celery worker compatible).

    Celery worker 进程中没有现成的事件循环，因此每次任务创建一个临时循环来执行异步协程。
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


async def _aclose_quietly(resource: Any, label: str) -> None:
    """Await ``resource.close()`` when *resource* is not None, never raising.

    Cleanup helper for ``finally`` blocks. A failure while closing one client is
    logged instead of propagated. The remaining clients still get closed, and
    the task's real outcome (its result or original exception) is not replaced
    by a cleanup error.

    Args:
        resource: An object exposing an awaitable ``close()``, or None.
        label: Human-readable name used in the warning log.
    """
    if resource is None:
        return
    try:
        await resource.close()
    except Exception as exc:
        logger.warning("Failed to close %s: %s", label, exc)


@celery_app.task(
    name="tasks.ingest_document",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
    acks_late=True,
    track_started=True,
)
def ingest_document_task(
    self,
    doc_id: str,
    file_path: str = "",
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Celery task: run the full document ingest pipeline.

    This is the main entry point called when OA pushes a new document.
    It creates a fresh IngestPipeline, runs it, records an ingest trace,
    best-effort syncs related documents, and handles retries.

    Args:
        doc_id: Unique document identifier from OA.
        file_path: Filename (not absolute path) of the document file.
            Resolved to full path via settings.file_storage_path.
        metadata: Document metadata dict.

    Returns:
        The pipeline result dict. If the pipeline reports ``status == "failed"``
        and retries remain, the task is retried with a linearly growing
        countdown. Once retries are exhausted, the failed result is returned.
        Unexpected exceptions are also retried. Once retries are exhausted,
        ``{"doc_id", "status": "failed", "error"}`` is returned.
    """
    metadata = metadata or {}
    # Phase A: trace_id = task_id (Celery UUID)
    trace_id = self.request.id
    attempt = self.request.retries + 1
    logger.info("Starting ingest for doc_id=%s, file=%s, attempt=%s", doc_id, file_path, attempt)

    # Update task state with progress
    self.update_state(
        state="PROCESSING",
        meta={"doc_id": doc_id, "step": "initializing", "progress": 0},
    )

    async def _run():
        from pathlib import Path

        import structlog
        from opensearchpy import AsyncOpenSearch

        from app.config import settings
        from app.core.ingest_pipeline import create_pipeline
        from app.core.ingest_trace_recorder import IngestTraceRecorder
        from app.infrastructure.es_client import ESClient

        # Bind trace context to structlog so all downstream logs carry these fields
        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(
            trace_id=trace_id, doc_id=doc_id, task_id=trace_id,
        )

        # Build a lightweight ES client for the recorder.
        # recorder_es only writes traces/events, not ACL paths — no RedisClient needed.
        es_kwargs: dict[str, Any] = {"hosts": [settings.es_host], "timeout": 60}
        if settings.es_username:
            es_kwargs["http_auth"] = (settings.es_username, settings.es_password)
        recorder_es = ESClient(AsyncOpenSearch(**es_kwargs))

        recorder = IngestTraceRecorder(
            recorder_es, trace_id=trace_id, doc_id=doc_id, task_id=trace_id,
        )

        # The try/except starts right after recorder creation so that ANY
        # exception (including Redis/pipeline init failures) triggers
        # finish_trace, preventing traces stuck in "running" forever.
        # Every client is pre-declared here so the finally block can reference
        # it no matter where an exception is raised.
        pipeline = None
        _redis_client = None
        _mysql_client = None
        try:
            # Mark the actual processing start time (excludes queue wait)
            await recorder.mark_processing_started()
            # Record task_started
            await recorder.record(
                "task_started", "completed",
                summary=f"Worker 开始处理，第 {attempt} 次尝试",
                details={"attempt": attempt, "worker": self.request.hostname or ""},
                service="worker",
            )
            if attempt > 1:
                # Reset trace status to running on retry so it doesn't stay "failed"
                await recorder.update_trace_fields({
                    "attempt_count": attempt,
                    "status": "running",
                    "error_code": None,
                    "error_message": None,
                    "finished_at": None,
                    "duration_ms": None,
                })

            # Resolve portable filename via local storage dir
            p = Path(file_path)
            resolved = str(settings.file_storage_path / p.name) if not p.is_absolute() else file_path

            # Create RedisClient for distributed locking on ACL recompute paths
            from app.infrastructure.redis_client import RedisClient
            _redis_client = RedisClient.from_settings()

            # Create MySQLClient for conversion log recording (best-effort)
            if settings.mysql_enabled and settings.mysql_host and settings.mysql_database:
                try:
                    from app.infrastructure.mysql_client import MySQLClient
                    _mysql_client = await MySQLClient.from_settings()
                except Exception as exc:
                    # stdlib-style logger: positional %-args only (a keyword
                    # field here would raise TypeError and abort the task).
                    logger.warning("celery_mysql_init_failed: %s", exc)

            # ── Phase A: Snapshot old related_docs before pipeline overwrites meta ──
            # Distinguish "field absent" (skip sync) vs "field present but empty" (clear all)
            _has_related_docs_field = "related_docs" in metadata
            _related_docs_from_meta = metadata.get("related_docs", [])
            _old_related_docs: list[dict] = []
            if _has_related_docs_field:
                try:
                    _old_resp = await recorder_es.raw.get(
                        index=settings.es_meta_index, id=doc_id, _source=["related_docs"],
                    )
                    _old_raw = _old_resp if isinstance(_old_resp, dict) else _old_resp.body
                    _old_related_docs = _old_raw.get("_source", {}).get("related_docs", [])
                except Exception:
                    _old_related_docs = []  # New document doesn't exist yet

            pipeline = create_pipeline(redis_client=_redis_client, mysql_client=_mysql_client)
            result = await pipeline.ingest_document(
                doc_id=doc_id,
                file_path=resolved,
                metadata=metadata,
                trace_id=trace_id,
                recorder=recorder,
            )

            # ── Phase B: Best-effort bidirectional related docs sync ──────
            status = result.get("status", "failed")
            if status in ("completed", "partial_failed") and _has_related_docs_field:
                try:
                    from app.core.related_docs_service import RelatedDocsService
                    _rd_service = RelatedDocsService(recorder_es)
                    # Resolve title with same fallback as _write_doc_meta
                    _resolved_title = metadata.get("title", "")
                    if not _resolved_title:
                        _resolved_title = Path(metadata.get("original_filename", "")).stem
                    _sync_result = await _rd_service.sync(
                        doc_id=doc_id,
                        new_related=_related_docs_from_meta,
                        current_title=_resolved_title,
                        known_old_related=_old_related_docs,
                    )
                    if _sync_result.get("warnings"):
                        logger.warning(
                            "Related docs sync warnings for doc_id=%s: %s",
                            doc_id, _sync_result["warnings"],
                        )
                        result.setdefault("warnings", []).extend(
                            [{"type": "related_docs", **w} for w in _sync_result["warnings"]]
                        )
                except Exception as e:
                    logger.warning("Related docs sync failed for doc_id=%s: %s", doc_id, str(e))
                    result.setdefault("warnings", []).append({
                        "type": "related_docs", "doc_id": "", "code": "SYNC_EXCEPTION",
                        "reason": f"关联同步异常: {e}",
                    })

            # Finish trace based on result status
            if status in ("completed", "partial_failed"):
                # For partial_failed: do NOT overwrite error_code/error_message —
                # the specific stage errors (e.g. INGEST_GRAPH_LLM_FAILED) are
                # already written to the trace summary by record().  Passing None
                # tells finish_trace to leave those fields untouched.
                await recorder.finish_trace(status)
            else:
                await recorder.finish_trace(
                    "failed",
                    error_code=result.get("error_code", "INGEST_PIPELINE_FAILED"),
                    error_message=result.get("error", ""),
                )

            return result
        except Exception as exc:
            try:
                await recorder.finish_trace(
                    "failed",
                    error_code="INGEST_UNEXPECTED_ERROR",
                    error_message=str(exc),
                )
            except Exception as trace_exc:
                # Do not let a trace-write failure replace the original error.
                logger.warning(
                    "Failed to finish trace for doc_id=%s: %s", doc_id, trace_exc,
                )
            raise
        finally:
            # Close each client independently via its public close(); one
            # failing close must not leak the others or mask the outcome.
            await _aclose_quietly(pipeline, "ingest pipeline")
            await _aclose_quietly(recorder_es, "recorder ES client")
            # RedisClient created for ACL distributed locking
            await _aclose_quietly(_redis_client, "Redis client")
            # MySQLClient created for conversion logging
            await _aclose_quietly(_mysql_client, "MySQL client")

    try:
        result = _run_async(_run())

        # Same default as inside _run(): a result without a status is a failure.
        if result.get("status", "failed") == "failed":
            error_msg = result.get("error", "Unknown error")
            logger.error("Ingest failed for doc_id=%s: %s", doc_id, error_msg)

            # Retry on transient errors
            if self.request.retries < self.max_retries:
                raise self.retry(
                    exc=RuntimeError(error_msg),
                    countdown=30 * (self.request.retries + 1),
                )

        return result

    except (Retry, MaxRetriesExceededError):
        # Let Celery handle retry / max-retries-exceeded natively
        raise
    except Exception as exc:
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception("Unexpected error in ingest for doc_id=%s: %s", doc_id, exc)
        if self.request.retries < self.max_retries:
            raise self.retry(exc=exc, countdown=30 * (self.request.retries + 1))
        return {
            "doc_id": doc_id,
            "status": "failed",
            "error": str(exc),
        }


@celery_app.task(
    name="tasks.update_permissions",
    bind=True,
    max_retries=3,
    default_retry_delay=10,
)
def update_permissions_task(
    self,
    doc_id: str,
    acl_ids: list[str] | None = None,
) -> dict[str, Any]:
    """Celery task: update a document's ACL and recompute chunk permissions.

    Called when OA pushes a permission change.  Strategy:
      1. Update the meta record's acl_ids (source of truth).
      2. Recompute the chunks' acl_ids from all metas sharing the content_hash.
      3. Sync the new acl_ids to service guides linked to the document.

    Args:
        doc_id: Document whose ACL changed.
        acl_ids: New ACL id list. ``None`` skips the update entirely.

    Returns:
        A status dict. ``"skipped"`` when acl_ids is None; ``"completed"``
        (with a ``note`` when the meta has no content_hash) otherwise.
        Any exception triggers a Celery retry using default_retry_delay.
    """
    logger.info("Updating permissions for doc_id=%s, acl_count=%s", doc_id, len(acl_ids or []))

    if acl_ids is None:
        return {"doc_id": doc_id, "status": "skipped", "reason": "acl_ids missing"}

    async def _run():
        from datetime import datetime, timezone

        from opensearchpy import AsyncOpenSearch

        from app.config import settings
        from app.infrastructure.es_client import ESClient
        from app.infrastructure.redis_client import RedisClient

        # RedisClient provides the distributed lock used by recompute_chunk_acl.
        redis_client = RedisClient.from_settings()
        # Nested try/finally: Redis is closed even if the ES client cannot be
        # constructed or raw_es.close() itself raises.
        try:
            es_kwargs: dict[str, Any] = {"hosts": [settings.es_host], "timeout": 60}
            if settings.es_username:
                es_kwargs["http_auth"] = (settings.es_username, settings.es_password)

            raw_es = AsyncOpenSearch(**es_kwargs)
            try:
                # Inject RedisClient into ESClient for recompute_chunk_acl distributed locking
                es_client = ESClient(raw_es, redis_client=redis_client)

                # Step 1: Update meta record's acl_ids (the source of truth)
                await raw_es.update(
                    index=settings.es_meta_index,
                    id=doc_id,
                    body={
                        "doc": {
                            "acl_ids": acl_ids,
                            "updated_at": datetime.now(timezone.utc).isoformat(),
                        }
                    },
                    refresh=True,
                )

                # Step 2: Get the content_hash from meta
                meta_resp = await raw_es.get(
                    index=settings.es_meta_index,
                    id=doc_id,
                    _source=["content_hash"],
                )
                raw = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
                content_hash = raw.get("_source", {}).get("content_hash", "")

                if not content_hash:
                    # NOTE(review): this early return also skips
                    # sync_service_guide_acl — confirm guides never link to a
                    # doc whose meta lacks a content_hash.
                    return {"doc_id": doc_id, "status": "completed", "note": "no content_hash"}

                # Step 3: Recompute chunks' ACL from all metas (source of truth)
                await es_client.recompute_chunk_acl(content_hash)
                guide_updates = await es_client.sync_service_guide_acl(doc_id, acl_ids)

                return {
                    "doc_id": doc_id,
                    "status": "completed",
                    "content_hash": content_hash,
                    "guide_updates": guide_updates,
                }
            finally:
                await raw_es.close()
        finally:
            # Release the Redis connection
            await redis_client.close()

    try:
        return _run_async(_run())
    except Exception as exc:
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception("Permission update failed for doc_id=%s: %s", doc_id, exc)
        raise self.retry(exc=exc)
