"""
Unified ingest trace recorder: a structured event timeline for document ingestion,
used for end-to-end tracing and problem diagnosis.

Phase A: trace_id = task_id (Celery UUID). No independent artifact index yet.

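Design notes:
  - Multiple recorder instances for the same trace sync seq from ES so that
    event sequence numbers increase monotonically
  - event_id uses a UUID to avoid collisions across instances
  - All write failures are caught and logged; they never block the main ingest pipeline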

Usage
-----
    recorder = IngestTraceRecorder(es_client, trace_id=task_id, doc_id=doc_id)
    await recorder.start_trace(source_type="webhook", ...)
    await recorder.record("file_type_detected", "completed", summary="Detected file type: pdf", details={...})
    await recorder.finish_trace("completed")

Multiple recorder instances for the same trace are safe: each one syncs
``seq`` from the trace summary in ES before its first write, and event IDs
use UUIDs instead of seq-based names to avoid collisions.
"""

from __future__ import annotations

import json
import time
import uuid
from datetime import datetime, timezone
from typing import Any

from app.config import settings
from app.utils.logger import get_logger

logger = get_logger(__name__)


class IngestTraceRecorder:
    """Record a structured event timeline for a single document ingest run.

    Each instance tracks one trace (one ingest run for one document).
    Events are written to OpenSearch with monotonically increasing ``seq``.

    Design constraints (Phase A):
    - ``trace_id`` equals the Celery ``task_id``.
    - ``seq`` is synced from the trace summary in ES on first write, so
      multiple short-lived recorder instances (API → worker → pipeline)
      produce a monotonically increasing, collision-free sequence.
    - ``event_id`` uses UUID to guarantee uniqueness across instances.
    - Write failures are logged and swallowed — never crash the main pipeline.
    """

    def __init__(
        self,
        es_client: Any,
        *,
        trace_id: str,
        doc_id: str,
        task_id: str | None = None,
    ) -> None:
        self._es = es_client
        self.trace_id = trace_id
        self.doc_id = doc_id
        self.task_id = task_id or trace_id
        self._seq = 0
        self._seq_synced = False          # True after first sync from ES
        self._started_at: str | None = None
        self._current_stage: str | None = None
        self._stage_start_time: float | None = None
        self._processing_started_at: str | None = None  # set when worker begins
        self._has_stage_failure = False    # any non-blocking stage failed
        self._trace_defaults: dict[str, Any] = {
            "source_type": "unknown",
            "file_type": "",
            "original_filename": "",
            "title": "",
            "content_hash": "",
            "operator": "",
        }

    # ── Seq synchronisation ───────────────────────────────────────────────

    async def _sync_seq(self) -> bool:
        """Fetch ``latest_seq`` and ``started_at`` from the trace summary.

        Called once before the first ``record()`` in this instance so that
        a newly-created recorder continues the sequence produced by an
        earlier instance (e.g. API-layer recorder → worker recorder).

        Used only for first-time initialisation; later seq increments go
        through the atomic ``_atomic_inc_seq()`` operation.
        """
        if self._seq_synced:
            return True
        try:
            resp = await self._es.raw.get(
                index=settings.es_trace_index,
                id=self.trace_id,
                params={"_source": "latest_seq,started_at,processing_started_at"},
            )
            raw = resp if isinstance(resp, dict) else resp.body
            src = raw.get("_source", {})
            self._seq = int(src.get("latest_seq", 0))
            if src.get("started_at") and not self._started_at:
                self._started_at = src["started_at"]
            if src.get("processing_started_at") and not self._processing_started_at:
                self._processing_started_at = src["processing_started_at"]
            self._seq_synced = True
            return True
        except Exception as exc:
            # Trace may not exist yet (first write); use defaults
            logger.debug("sync_seq_failed", trace_id=self.trace_id, error=str(exc))
            return False

    async def _atomic_inc_seq(self) -> int:
        """Atomically increment ``latest_seq`` on the trace summary and return the new value.

        Uses a Painless script so that concurrent recorder instances do not
        hand out the same seq, with ``retry_on_conflict=3`` to absorb
        version conflicts between concurrent updates.
        """
        try:
            result = await self._es.raw.update(
                index=settings.es_trace_index,
                id=self.trace_id,
                body={
                    "script": {
                        "lang": "painless",
                        "source": "ctx._source.latest_seq = (ctx._source.latest_seq ?: 0) + 1",
                    },
                },
                params={"_source": "latest_seq", "retry_on_conflict": 3},
            )
            raw = result if isinstance(result, dict) else result.body
            # Prefer get._source from the standard OpenSearch update response
            get_source = raw.get("get", {}).get("_source")
            if get_source and "latest_seq" in get_source:
                return int(get_source["latest_seq"])

            # Fallback: if the response lacks get._source (OpenSearch version
            # variance), read latest_seq with a separate GET
            resp = await self._es.raw.get(
                index=settings.es_trace_index,
                id=self.trace_id,
                params={"_source": "latest_seq"},
            )
            fallback_raw = resp if isinstance(resp, dict) else resp.body
            return int(fallback_raw.get("_source", {}).get("latest_seq", 1))
        except Exception as exc:
            # Fall back to a local increment on failure; never block the pipeline
            logger.warning(
                "atomic_inc_seq_failed",
                trace_id=self.trace_id,
                error=str(exc),
            )
            self._seq += 1
            return self._seq

    def _build_trace_body(
        self,
        started_at: str,
        *,
        status: str = "running",
        current_stage: str = "upload_received",
    ) -> dict[str, Any]:
        """Build a full trace summary document for create/upsert paths."""
        return {
            "trace_id": self.trace_id,
            "doc_id": self.doc_id,
            "task_id": self.task_id,
            "content_hash": self._trace_defaults["content_hash"],
            "source_type": self._trace_defaults["source_type"],
            "status": status,
            "current_stage": current_stage,
            "file_type": self._trace_defaults["file_type"],
            "original_filename": self._trace_defaults["original_filename"],
            "title": self._trace_defaults["title"] or self._trace_defaults["original_filename"],
            "operator": self._trace_defaults["operator"],
            "attempt_count": 1,
            "latest_seq": 0,
            "started_at": started_at,
            "processing_started_at": None,
            "finished_at": None,
            # None means not yet finished; finish_trace computes the actual duration
            "duration_ms": None,
            "error_code": None,
            "error_message": None,
            "artifact_count": 0,
            "created_at": started_at,
            "updated_at": started_at,
        }

    async def _ensure_trace_summary(self, current_stage: str) -> None:
        """Ensure the trace summary exists before recording events.

        This recovers from API-side trace creation failures by lazily upserting
        a minimal summary document from the worker-side recorder instance.
        """
        if await self._sync_seq():
            return

        now_iso = self._started_at or datetime.now(timezone.utc).isoformat()
        self._started_at = now_iso
        upsert_body = self._build_trace_body(
            now_iso,
            status="running",
            current_stage=current_stage,
        )
        await self._safe_update(
            settings.es_trace_index,
            self.trace_id,
            {
                "current_stage": current_stage,
                "updated_at": datetime.now(timezone.utc).isoformat(),
            },
            upsert=upsert_body,
        )
        await self._sync_seq()

    # ── Trace lifecycle ──────────────────────────────────────────────────

    async def start_trace(
        self,
        *,
        source_type: str = "webhook",
        file_type: str = "",
        original_filename: str = "",
        title: str = "",
        content_hash: str = "",
        operator: str = "",
    ) -> None:
        """Create or update the trace summary record."""
        now_iso = datetime.now(timezone.utc).isoformat()
        self._started_at = now_iso
        self._trace_defaults.update(
            {
                "source_type": source_type,
                "file_type": file_type,
                "original_filename": original_filename,
                "title": title,
                "content_hash": content_hash,
                "operator": operator,
            }
        )
        body = self._build_trace_body(now_iso)
        self._seq_synced = await self._safe_index(
            settings.es_trace_index,
            self.trace_id,
            body,
        )

    async def mark_processing_started(self) -> None:
        """Mark the moment the worker actually begins processing (excludes queue wait time)."""
        now_iso = datetime.now(timezone.utc).isoformat()
        self._processing_started_at = now_iso
        await self._safe_update(
            settings.es_trace_index,
            self.trace_id,
            {"processing_started_at": now_iso, "updated_at": now_iso},
        )

    async def finish_trace(
        self,
        status: str,
        *,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update the trace summary with final status.

        ``started_at`` is read from ES if this recorder instance was not
        the one that called ``start_trace()``.
        """
        # Ensure we have started_at from ES even if this is a fresh instance
        await self._ensure_trace_summary(self._current_stage or status)

        now_iso = datetime.now(timezone.utc).isoformat()
        duration_ms = 0
        # Prefer processing_started_at (actual worker start) to exclude queue wait time
        ref_time = self._processing_started_at or self._started_at
        if ref_time:
            try:
                started = datetime.fromisoformat(ref_time)
                duration_ms = int(
                    (datetime.now(timezone.utc) - started).total_seconds() * 1000
                )
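            except Exception as exc:
                # Unparseable timestamp: keep duration_ms at 0 rather than fail the update
                logger.debug("duration_calc_failed", trace_id=self.trace_id, error=str(exc))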

        update_fields: dict[str, Any] = {
            "status": status,
            "finished_at": now_iso,
            "duration_ms": duration_ms,
            "latest_seq": self._seq,
            "updated_at": now_iso,
        }
        # Only overwrite error fields when explicitly provided.
        # For partial_failed the caller passes None so the per-stage
        # error already written by record() is preserved.
        if error_code is not None:
            update_fields["error_code"] = error_code
        if error_message is not None:
            update_fields["error_message"] = error_message[:2000]

        await self._safe_update(
            settings.es_trace_index,
            self.trace_id,
            update_fields,
        )

    async def update_trace_fields(self, fields: dict[str, Any]) -> None:
        """Partial update on the trace summary (e.g. content_hash, file_type)."""
        # Copy so the caller's dict is not mutated
        fields = {**fields, "updated_at": datetime.now(timezone.utc).isoformat()}
        await self._safe_update(settings.es_trace_index, self.trace_id, fields)

    # ── Event recording ──────────────────────────────────────────────────

    async def record(
        self,
        stage: str,
        event_type: str,
        *,
        summary: str = "",
        details: dict[str, Any] | None = None,
        service: str = "worker",
        error_code: str | None = None,
        error_message: str | None = None,
        duration_ms: int | None = None,
        severity: str = "info",
        artifact_refs: list[str] | None = None,
    ) -> None:
        """Append a structured event to the timeline.

        Increments ``seq`` atomically and updates the trace summary to match.
        """
        # Sync seq from ES on first call to continue from previous recorder
        await self._ensure_trace_summary(stage)

        # Use an atomic increment instead of a local self._seq += 1 so that
        # concurrent recorders do not collide on seq
        new_seq = await self._atomic_inc_seq()
        self._seq = new_seq  # finish_trace depends on self._seq
        now_iso = datetime.now(timezone.utc).isoformat()

        if event_type == "started":
            self._current_stage = stage
            self._stage_start_time = time.monotonic()
        elif event_type in ("completed", "failed") and duration_ms is None:
            if self._stage_start_time is not None and self._current_stage == stage:
                duration_ms = int(
                    (time.monotonic() - self._stage_start_time) * 1000
                )
                self._stage_start_time = None

        if error_code or error_message:
            severity = "error"

        if event_type == "failed":
            self._has_stage_failure = True

        # UUID-based event_id: use the full 32-char hex for global uniqueness
        event_id = f"evt_{uuid.uuid4().hex}"
        body: dict[str, Any] = {
            "event_id": event_id,
            "trace_id": self.trace_id,
            "doc_id": self.doc_id,
            "task_id": self.task_id,
            "stage": stage,
            "event_type": event_type,
            "status": event_type,
            "seq": self._seq,
            "attempt": 1,
            "service": service,
            "operator": "",
            "summary": summary,
            "duration_ms": duration_ms or 0,
            "severity": severity,
            "error_code": error_code,
            "error_message": error_message[:2000] if error_message else None,
            "artifact_refs": artifact_refs or [],
            "details": details or {},
            "timestamp": now_iso,
            "created_at": now_iso,
        }

        await self._safe_index(settings.es_event_index, event_id, body)

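        # Update trace summary with current stage. latest_seq is written here as a
        # plain field, so a newer value from a concurrent recorder can be overwritten.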
        trace_update: dict[str, Any] = {
            "latest_seq": self._seq,
            "updated_at": now_iso,
        }
        if event_type in ("started", "completed"):
            trace_update["current_stage"] = stage
        if error_code:
            trace_update["error_code"] = error_code
            # Truncate to match the per-event error_message limit
            trace_update["error_message"] = error_message[:2000] if error_message else None
        await self._safe_update(settings.es_trace_index, self.trace_id, trace_update)

    # ── Convenience: paired started/completed ────────────────────────────

    async def record_stage_start(
        self,
        stage: str,
        *,
        summary: str = "",
        details: dict[str, Any] | None = None,
        service: str = "worker",
    ) -> None:
        """Record a stage started event."""
        await self.record(
            stage, "started", summary=summary, details=details, service=service
        )

    async def record_stage_complete(
        self,
        stage: str,
        *,
        summary: str = "",
        details: dict[str, Any] | None = None,
        service: str = "worker",
    ) -> None:
        """Record a stage completed event."""
        await self.record(
            stage, "completed", summary=summary, details=details, service=service
        )

    async def record_stage_failed(
        self,
        stage: str,
        *,
        summary: str = "",
        details: dict[str, Any] | None = None,
        error_code: str = "",
        error_message: str = "",
        service: str = "worker",
    ) -> None:
        """Record a stage failure event."""
        await self.record(
            stage,
            "failed",
            summary=summary,
            details=details,
            error_code=error_code,
            error_message=error_message,
            service=service,
            severity="error",
        )

    @property
    def has_stage_failure(self) -> bool:
        """Whether any non-blocking stage recorded a failure in this instance."""
        return self._has_stage_failure

    # ── Phase B: Artifact recording ──────────────────────────────────────

    async def record_artifact(
        self,
        stage: str,
        artifact_type: str,
        *,
        payload_json: dict[str, Any] | None = None,
        payload_text: str | None = None,
        preview_text: str = "",
        retention_level: str = "standard",
        is_redacted: bool = False,
    ) -> str:
        """Write an artifact record (e.g. a parse result or LLM output) and return its ID.

        The returned ID is used to associate the artifact with events.
        """
        artifact_id = f"art_{uuid.uuid4().hex[:12]}"
        now_iso = datetime.now(timezone.utc).isoformat()

        content_bytes = 0
        if payload_json:
            content_bytes = len(json.dumps(payload_json, ensure_ascii=False).encode())
        elif payload_text:
            content_bytes = len(payload_text.encode())

        body: dict[str, Any] = {
            "artifact_id": artifact_id,
            "trace_id": self.trace_id,
            "event_id": None,
            "doc_id": self.doc_id,
            "stage": stage,
            "artifact_type": artifact_type,
            "retention_level": retention_level,
            "is_redacted": is_redacted,
            "content_encoding": "json_inline" if payload_json else "text_inline",
            "preview_text": preview_text[:500] if preview_text else "",
            "payload_json": payload_json,
            "payload_text": payload_text[:32768] if payload_text else None,
            "storage_backend": "opensearch",
            "storage_path": None,
            "content_bytes": content_bytes,
            "expires_at": None,
            "created_at": now_iso,
        }

        await self._safe_index(settings.es_artifact_index, artifact_id, body)
        return artifact_id

    # ── Internal helpers ─────────────────────────────────────────────────

    async def _safe_index(
        self, index: str, doc_id: str, body: dict[str, Any]
    ) -> bool:
        """Index a document, swallowing errors to never block the pipeline."""
        try:
            await self._es.raw.index(index=index, id=doc_id, body=body)
            return True
        except Exception as exc:
            logger.warning(
                "trace_write_failed",
                index=index,
                doc_id=doc_id,
                error=str(exc),
            )
            return False

    async def _safe_update(
        self,
        index: str,
        doc_id: str,
        fields: dict[str, Any],
        *,
        upsert: dict[str, Any] | None = None,
    ) -> bool:
        """Partial update, swallowing errors."""
        try:
            body: dict[str, Any] = {"doc": fields}
            if upsert is not None:
                body["upsert"] = upsert
            await self._es.raw.update(
                index=index,
                id=doc_id,
                body=body,
            )
            return True
        except Exception as exc:
            logger.warning(
                "trace_update_failed",
                index=index,
                doc_id=doc_id,
                error=str(exc),
            )
            return False
