"""Research record lifecycle service.

编排层：位于 records API 与 ResearchEngine 之间。
负责记录创建、计划生成、运行启动（Celery 投递）、SSE 订阅。
"""

from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import Any, AsyncIterator

from app.api.schemas.research import ResearchChunk, ResearchTask
from app.core.research_engine import ResearchEngine
from app.infrastructure.redis_client import RedisClient
from app.infrastructure.research_record_store import (
    ResearchRecordRunStore,
    ResearchRecordStore,
)
from app.utils.logger import get_logger

# Module logger; call sites in this file pass an event name plus keyword fields.
logger = get_logger(__name__)

# Prefix of the per-run Redis Stream key; run events are read from
# f"{_STREAM_KEY_PREFIX}{run_id}:events".
_STREAM_KEY_PREFIX = "research:run:"


def _utcnow() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    utc = timezone.utc
    return datetime.now(tz=utc)


def _new_id() -> str:
    """Return a fresh random identifier: 32 lowercase hex chars, no dashes."""
    return format(uuid.uuid4().int, "032x")


class ResearchRecordService:
    """Orchestrates the research record lifecycle.

    Sits between the records API and ``ResearchEngine``: creates and versions
    records, runs pre-plan clarification and post-research chat as direct
    streams, generates plans, dispatches Celery runs and replays run events
    from Redis Streams.
    """

    def __init__(
        self,
        record_store: ResearchRecordStore,
        run_store: ResearchRecordRunStore,
        redis: RedisClient,
        engine: ResearchEngine,
    ) -> None:
        self._store = record_store
        self._run_store = run_store
        self._redis = redis
        self._engine = engine

    # ------------------------------------------------------------------
    # Record CRUD
    # ------------------------------------------------------------------

    async def create_record(
        self,
        user_id: str,
        title: str,
        mode: str,
        output_template: str,
        task: dict[str, Any],
        imported_items: list[dict] | None = None,
    ) -> str:
        """Persist a new record and return its generated id.

        ``task`` is stored as ``task_json`` and ``imported_items`` as
        ``imported_items_json``.
        """
        record_id = _new_id()
        await self._store.create_record(
            record_id=record_id,
            user_id=user_id,
            title=title,
            mode=mode,
            output_template=output_template,
            task_json=task,
            imported_items_json=imported_items,
        )
        logger.info("record_created", record_id=record_id, user_id=user_id)
        return record_id

    # ------------------------------------------------------------------
    # Regenerate → new version
    # ------------------------------------------------------------------

    @staticmethod
    def _is_duplicate_key_error(exc: BaseException) -> bool:
        """Return True if *exc* looks like a MySQL duplicate-key (1062) error.

        Checks the driver error code in ``exc.args[0]`` or the server message
        text; a bare "1062" substring is not enough because ids embedded in
        unrelated error messages can contain it.
        """
        args = getattr(exc, "args", ())
        if args and args[0] == 1062:
            return True
        return "Duplicate entry" in str(exc)

    async def regenerate_record(
        self,
        record_id: str,
        user_id: str,
        task: dict[str, Any],
    ) -> dict[str, Any]:
        """Create a new version of an existing record.

        The new record shares the original's ``root_record_id`` (backfilled on
        the original when it was empty), points at the original through
        ``parent_record_id`` and gets ``MAX(version_no) + 1`` within that root.
        ``mode`` / ``output_template`` are taken from ``task`` when present.

        Returns ``{"id": ..., "title": ..., "version_no": ...}``.

        Raises:
            ValueError: record missing, status not regenerable, or a
                concurrent regenerate already took the same version number.
        """
        original = await self._store.get_record(record_id, user_id)
        if original is None:
            raise ValueError(f"Record {record_id} not found")
        if original["status"] not in ("planned", "completed", "failed"):
            raise ValueError(
                f"Cannot regenerate record in status {original['status']}"
            )

        # The first record of a lineage acts as its own root.
        root_id = original["root_record_id"] or original["id"]
        if not original["root_record_id"]:
            # Backfill so the original matches the root filter used below.
            await self._store.update_record(
                original["id"], user_id, root_record_id=original["id"],
            )

        # NOTE: reaches into the store's private DB handle. The read itself is
        # not atomic; two concurrent regenerates are resolved by the
        # uq_root_version UNIQUE key on insert (handled below).
        row = await self._store._db.fetch_one(
            "SELECT COALESCE(MAX(version_no), 0) + 1 AS next_ver "
            "FROM research_records WHERE root_record_id = %s",
            (root_id,),
        )
        next_ver = int(row["next_ver"]) if row else 2

        # mode / output_template from submitted task (I1: user may have changed)
        mode = task.get("mode", original["mode"])
        output_template = task.get("output_template", original["output_template"])

        new_id = _new_id()
        try:
            await self._store.create_record(
                record_id=new_id,
                user_id=user_id,
                title=original["title"],
                mode=mode,
                output_template=output_template,
                task_json=task,
                parent_record_id=record_id,
                root_record_id=root_id,
                version_no=next_ver,
            )
        except Exception as exc:
            # UNIQUE KEY uq_root_version violated → concurrent regenerate
            if self._is_duplicate_key_error(exc):
                raise ValueError("正在创建新版本，请勿重复提交") from exc
            raise

        logger.info(
            "record_regenerated",
            new_id=new_id, parent_id=record_id, version_no=next_ver,
        )
        return {"id": new_id, "title": original["title"], "version_no": next_ver}

    # ------------------------------------------------------------------
    # Seed doc extraction  (C2)
    # ------------------------------------------------------------------

    @staticmethod
    def extract_seed_doc_ids(record: dict[str, Any]) -> list[str]:
        """Collect seed doc ids from ``task_json`` and ``imported_items_json``.

        Sources, in order:
        - ``task_json.required_doc_ids``
        - imported items by ``item_type`` (see basket.d.ts):
          document / snippet → ``doc_id``; matter → ``governing_doc_ids``;
          answer → ``references[].doc_id``

        Falsy ids and non-dict items/references are skipped. The result is
        de-duplicated and keeps first-seen order, so it is deterministic.
        """
        # A dict keeps insertion order, unlike a set.
        seen: dict[str, None] = {}

        def _add(did: Any) -> None:
            if did:
                seen[did] = None

        task_json = record.get("task_json") or {}
        for did in task_json.get("required_doc_ids") or []:
            _add(did)

        for item in record.get("imported_items_json") or []:
            if not isinstance(item, dict):
                continue
            item_type = item.get("item_type", "")
            if item_type in ("document", "snippet"):
                _add(item.get("doc_id"))
            elif item_type == "matter":
                for did in item.get("governing_doc_ids") or []:
                    _add(did)
            elif item_type == "answer":
                for ref in item.get("references") or []:
                    if isinstance(ref, dict):
                        _add(ref.get("doc_id"))

        return list(seen)

    # ------------------------------------------------------------------
    # Pre-plan clarification (F1) — direct stream, no Celery
    # ------------------------------------------------------------------

    @staticmethod
    def _build_user_materials(imported_items: list[dict[str, Any]]) -> str:
        """Render imported items (first 10) plus uploaded-file text for the prompt.

        Uploaded-file text is truncated to 5000 characters per file.
        """
        if not imported_items:
            return "用户未提供任何资料。"

        materials_lines: list[str] = []
        uploaded_texts: list[str] = []
        for item in imported_items[:10]:
            item_type = item.get("item_type", "unknown")
            title = item.get("title", "未知")
            materials_lines.append(f"- [{item_type}] {title}")
            if item_type == "uploaded_file" and item.get("extracted_text"):
                text = item["extracted_text"]
                if len(text) > 5000:
                    text = text[:5000] + "...(已截断)"
                uploaded_texts.append(f"【文件：{title}】\n{text}")

        user_materials = f"用户已提供 {len(imported_items)} 项资料：\n" + "\n".join(materials_lines)
        if uploaded_texts:
            user_materials += "\n\n--- 上传文件内容 ---\n" + "\n\n".join(uploaded_texts)
        return user_materials

    async def _kb_presearch(
        self,
        task_json: dict[str, Any],
        user_id: str,
        acl_tokens: list[str] | None,
    ) -> str:
        """Search the KB for the task topic/question and summarize the hits.

        Best-effort: any failure is logged and replaced by a fallback message
        so clarification can continue without KB context.
        """
        try:
            query = f"{task_json.get('topic', '')} {task_json.get('question', '')}"
            from app.core.permission import PermissionContext
            perm = PermissionContext(user_id=user_id, acl_tokens=acl_tokens or [])
            es_docs = await self._engine._es_search(query.strip(), perm)
            if not es_docs:
                return "知识库未检索到相关文档，可能需要用户补充资料或调整研究范围。"
            shown = es_docs[:8]
            kb_lines = [
                f"- {doc.get('title', '未知')} (ID: {doc.get('doc_id', '')})"
                for doc in shown
            ]
            return (
                f"知识库共检索到 {len(es_docs)} 篇相关文档"
                f"（展示前 {len(shown)} 篇）：\n"
                + "\n".join(kb_lines)
            )
        except Exception as exc:
            logger.warning("clarify_kb_search_failed", error=str(exc))
            return "知识库检索失败，将仅基于用户提供的资料进行评估。"

    async def clarify(
        self,
        record_id: str,
        user_id: str,
        messages: list[dict[str, str]],
        *,
        enable_kb_search: bool = True,
        acl_tokens: list[str] | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Run one clarification round and yield SSE chunk dicts.

        Direct stream mode: no Celery, no Redis Stream. Each round appends the
        new ``messages`` and the assistant reply to
        ``clarification_messages_json``. When the LLM reports ``ready``, its
        ``task_patch`` (if any) is merged into ``task_json`` and the record
        returns to ``draft``; otherwise it stays ``clarifying``.

        If ``enable_kb_search`` is True, the knowledge base is pre-searched
        (best-effort) to assess data coverage before LLM evaluation.

        Raises:
            ValueError: record missing or not in ``draft`` / ``clarifying``.
        """
        import json

        from app.prompts.research_prompts import (
            RESEARCH_CLARIFY_SYSTEM,
            RESEARCH_CLARIFY_USER,
        )

        record = await self._store.get_record(record_id, user_id)
        if record is None:
            raise ValueError(f"Record {record_id} not found")
        if record["status"] not in ("draft", "clarifying"):
            raise ValueError(f"Cannot clarify record in status {record['status']}")

        task_json = record.get("task_json") or {}
        mode = task_json.get("mode", "deep")
        # Quick mode allows a single follow-up question.
        max_questions = 1 if mode == "quick" else 3

        # Conversation = persisted history + this round's new messages.
        existing_msgs = record.get("clarification_messages_json") or []
        all_msgs = existing_msgs + messages
        history_text = "\n".join(
            f"{'用户' if m.get('role') == 'user' else '助手'}: {m.get('content', '')}"
            for m in all_msgs
        ) or "（无历史对话）"

        user_materials = self._build_user_materials(
            record.get("imported_items_json") or []
        )

        kb_search_results = "（未开启知识库检索）"
        if enable_kb_search:
            yield {"type": "clarification_thinking", "content": "正在检索知识库..."}
            kb_search_results = await self._kb_presearch(task_json, user_id, acl_tokens)

        yield {"type": "clarification_thinking", "content": "正在综合评估研究任务..."}

        system_prompt = RESEARCH_CLARIFY_SYSTEM.format(max_questions=max_questions)
        user_prompt = RESEARCH_CLARIFY_USER.format(
            topic=task_json.get("topic", ""),
            question=task_json.get("question", ""),
            goal=task_json.get("goal", ""),
            mode=mode,
            output_template=task_json.get("output_template", "comprehensive"),
            depth_level=task_json.get("depth_level", "deep"),
            user_materials=user_materials,
            kb_search_results=kb_search_results,
            history=history_text,
        )

        await self._store.update_record(record_id, user_id, status="clarifying")

        # LLM call (non-streaming JSON mode). Failures and non-object payloads
        # fall back to a generic follow-up question instead of aborting.
        fallback = {
            "ready": False,
            "questions": ["请提供更详细的研究主题描述。"],
            "note": "解析失败，请重试。",
        }
        try:
            result = await self._engine._llm.chat_json(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.3,
            )
        except Exception as exc:
            logger.warning("clarify_llm_failed", record_id=record_id, error=str(exc))
            result = fallback
        if not isinstance(result, dict):
            logger.warning(
                "clarify_llm_bad_payload",
                record_id=record_id, payload_type=type(result).__name__,
            )
            result = fallback

        is_ready = result.get("ready", False)
        questions = result.get("questions") or []
        if isinstance(questions, str):
            # A bare string would otherwise be yielded character by character.
            questions = [questions]
        task_patch = result.get("task_patch") or {}
        if not isinstance(task_patch, dict):
            task_patch = {}
        assessment = result.get("assessment")
        note = result.get("note") or ""

        if assessment:
            yield {"type": "assessment", "content": json.dumps(assessment, ensure_ascii=False)}

        # Persist the full conversation, including the raw assistant payload.
        assistant_msg = {"role": "assistant", "content": json.dumps(result, ensure_ascii=False)}
        await self._store.update_record(
            record_id, user_id,
            clarification_messages_json=all_msgs + [assistant_msg],
        )

        if is_ready:
            if task_patch:
                merged_task = {**task_json, **task_patch}
                await self._store.update_record(
                    record_id, user_id,
                    task_json=merged_task,
                    status="draft",
                )
                yield {"type": "task_patch", "content": json.dumps(task_patch, ensure_ascii=False)}
            else:
                await self._store.update_record(record_id, user_id, status="draft")

            if note:
                yield {"type": "clarification_note", "content": note}
            yield {"type": "ready_signal", "content": "任务信息充分，可以生成研究计划。"}
        else:
            for q in questions:
                yield {"type": "clarification_question", "content": q}
            if note:
                yield {"type": "clarification_note", "content": note}

        yield {"type": "done", "content": ""}

    # ------------------------------------------------------------------
    # Post-research chat (F2) — direct stream, no Celery
    # ------------------------------------------------------------------

    async def chat(
        self,
        record_id: str,
        user_id: str,
        user_message: str,
    ) -> AsyncIterator[dict[str, Any]]:
        """Run one post-research chat turn and yield SSE chunk dicts.

        Streams LLM output as ``chat_content`` chunks, emits
        ``suggest_regenerate`` when the reply contains the
        ``[SUGGEST_REGENERATE]`` marker, and persists ``chat_messages_json``
        after each round.

        Raises:
            ValueError: record missing or not ``completed``.
        """
        from app.prompts.research_prompts import (
            RESEARCH_POST_CHAT_SYSTEM,
            RESEARCH_POST_CHAT_USER,
        )

        record = await self._store.get_record(record_id, user_id)
        if record is None:
            raise ValueError(f"Record {record_id} not found")
        if record["status"] != "completed":
            raise ValueError("Chat is only available for completed records")

        # Build truncated research context (C5: limit to ~2000 tokens)
        task_json = record.get("task_json") or {}
        plan_json = record.get("plan_json") or {}
        refs = record.get("references_json") or []

        report_data = await self._store.get_report(record_id, user_id)
        # Guard against a missing report row and a NULL report_json column.
        report_json = (report_data or {}).get("report_json") or {}

        context_parts = [
            f"研究主题: {task_json.get('topic', '')}",
            f"核心问题: {task_json.get('question', '')}",
        ]
        if plan_json.get("summary"):
            context_parts.append(f"计划摘要: {plan_json['summary']}")
        if report_json.get("executive_summary"):
            context_parts.append(f"执行摘要: {report_json['executive_summary'][:500]}")
        sections = report_json.get("sections") or []
        if sections:
            titles = [
                s.get("title") or "" for s in sections[:8] if isinstance(s, dict)
            ]
            context_parts.append(f"章节标题: {', '.join(titles)}")
        # Top 5 references (title only)
        for ref in refs[:5]:
            if isinstance(ref, dict):
                context_parts.append(f"引用: {ref.get('title', '')}")
        research_context = "\n".join(context_parts)

        # Last 10 messages, each truncated to 200 characters.
        existing_msgs = record.get("chat_messages_json") or []
        history_text = "\n".join(
            f"{'用户' if m.get('role') == 'user' else '助手'}: {(m.get('content') or '')[:200]}"
            for m in existing_msgs[-10:]
        ) or "（无历史对话）"

        system_prompt = RESEARCH_POST_CHAT_SYSTEM.format(
            research_context=research_context,
            history=history_text,
        )
        user_prompt = RESEARCH_POST_CHAT_USER.format(message=user_message)

        yield {"type": "chat_thinking", "content": "正在思考..."}

        # LLM streaming call (llm.chat() is AsyncIterator[str])
        pieces: list[str] = []
        async for chunk in self._engine._llm.chat(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.5,
        ):
            pieces.append(chunk)
            yield {"type": "chat_content", "content": chunk}
        full_response = "".join(pieces)

        marker = "[SUGGEST_REGENERATE]"
        if marker in full_response:
            # Text after the marker is the reason; fall back when it is empty.
            reason = full_response.partition(marker)[2].strip() or "建议基于新参数重新研究"
            yield {"type": "suggest_regenerate", "content": reason}

        # Persist conversation (D2: every round)
        new_msgs = existing_msgs + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": full_response},
        ]
        await self._store.update_record(
            record_id, user_id,
            chat_messages_json=new_msgs,
        )

        yield {"type": "done", "content": ""}

    # ------------------------------------------------------------------
    # Plan generation
    # ------------------------------------------------------------------

    async def generate_plan(self, record_id: str, user_id: str) -> dict[str, Any]:
        """Generate a plan from ``task_json`` and move the record to ``planned``.

        build_plan() does NOT need perm (research_engine.py:697).

        Raises:
            ValueError: record missing, ``clarifying`` or ``running``, or it
                has no ``task_json``.
        """
        record = await self._store.get_record(record_id, user_id)
        if record is None:
            raise ValueError(f"Record {record_id} not found")
        status = record["status"]
        if status == "clarifying":
            raise ValueError("Record is in clarifying state, cannot generate plan")
        if status == "running":
            # Re-planning mid-run would overwrite plan_json under the worker
            # and flip the status back to "planned", hiding the in-flight run.
            raise ValueError(f"Cannot generate plan for record in status {status}")

        task_json = record.get("task_json")
        if not task_json:
            raise ValueError("Record has no task_json")

        task = ResearchTask(**task_json)
        seed_doc_ids = self.extract_seed_doc_ids(record)

        plan = await self._engine.build_plan(task, seed_doc_ids=seed_doc_ids or None)
        plan_dict = plan.model_dump()

        await self._store.update_record(
            record_id, user_id,
            plan_json=plan_dict,
            status="planned",
        )
        logger.info("plan_generated", record_id=record_id)
        return plan_dict

    # ------------------------------------------------------------------
    # Run lifecycle  (Celery delegation)
    # ------------------------------------------------------------------

    async def start_run(
        self,
        record_id: str,
        user_id: str,
        acl_tokens: list[str],
    ) -> str:
        """Create a full run, mark the record ``running`` and dispatch Celery.

        Returns the new run_id. On dispatch failure the run and record are
        marked ``failed`` and the exception is re-raised.

        Raises:
            ValueError: record missing, not runnable, or a run is in progress.
        """
        record = await self._store.get_record(record_id, user_id)
        if record is None:
            raise ValueError(f"Record {record_id} not found")
        if record["status"] not in ("planned", "completed", "failed"):
            raise ValueError(f"Cannot run record in status {record['status']}")

        # H4: prevent duplicate concurrent runs
        latest_run = await self._run_store.get_latest_run(record_id, user_id)
        if latest_run and latest_run["status"] in ("pending", "running"):
            raise ValueError("A run is already in progress for this record")

        run_id = _new_id()
        await self._run_store.create_run(
            run_id=run_id,
            record_id=record_id,
            user_id=user_id,
            run_type="full",
        )
        # Push record to "running" before dispatch so refresh can recover (H1)
        await self._store.update_record(record_id, user_id, status="running")

        try:
            from app.tasks.research_task import execute_research_run_task

            execute_research_run_task.delay(
                run_id=run_id,
                record_id=record_id,
                user_id=user_id,
                acl_tokens=acl_tokens,
            )
        except Exception as exc:
            # C17: full run queue failure → record.status = failed
            logger.exception("celery_dispatch_failed", run_id=run_id)
            await self._run_store.update_run(
                run_id, status="failed", error=str(exc),
                completed_at=_utcnow(),
            )
            await self._store.update_record(
                record_id, user_id,
                status="failed",
                last_error=f"任务投递失败: {exc}",
            )
            raise

        logger.info("run_started", run_id=run_id, record_id=record_id)
        return run_id

    async def start_section_rerun(
        self,
        record_id: str,
        user_id: str,
        acl_tokens: list[str],
        section_title: str,
        section_summary: str | None = None,
        source_doc_ids: list[str] | None = None,
    ) -> str:
        """Create a section-rerun run and dispatch the Celery task.

        Only ``completed`` records with an existing report qualify. Returns the
        new run_id. On dispatch failure the run is marked ``failed`` and the
        record is restored to ``completed`` before the exception is re-raised.

        Raises:
            ValueError: record missing, not completed, without a report, or a
                run is already in progress.
        """
        record = await self._store.get_record(record_id, user_id)
        if record is None:
            raise ValueError(f"Record {record_id} not found")
        # H2: only completed records with a report can be rerun
        if record["status"] != "completed":
            raise ValueError(f"Cannot rerun section on record in status {record['status']}")
        # P1-6: verify report actually exists before allowing rerun
        existing_report = await self._store.get_report(record_id, user_id)
        if existing_report is None:
            raise ValueError("Cannot rerun section: no report found for this record")
        # H4: prevent duplicate concurrent runs
        latest_run = await self._run_store.get_latest_run(record_id, user_id)
        if latest_run and latest_run["status"] in ("pending", "running"):
            raise ValueError("A run is already in progress for this record")

        run_id = _new_id()
        await self._run_store.create_run(
            run_id=run_id,
            record_id=record_id,
            user_id=user_id,
            run_type="section_rerun",
            section_title=section_title,
        )
        # Push status before dispatch (H1)
        await self._store.update_record(record_id, user_id, status="running")

        try:
            from app.tasks.research_task import execute_section_rerun_task

            execute_section_rerun_task.delay(
                run_id=run_id,
                record_id=record_id,
                user_id=user_id,
                acl_tokens=acl_tokens,
                section_title=section_title,
                section_summary=section_summary,
                source_doc_ids=source_doc_ids,
            )
        except Exception as exc:
            # C17: a section rerun dispatch failure must not change the
            # record's status. It was pushed to "running" above, so restore
            # the "completed" status verified earlier; the existing report
            # stays usable.
            logger.exception("celery_dispatch_failed", run_id=run_id)
            await self._run_store.update_run(
                run_id, status="failed", error=str(exc),
                completed_at=_utcnow(),
            )
            await self._store.update_record(
                record_id, user_id,
                status="completed",
                last_error=f"章节重跑任务投递失败: {exc}",
            )
            raise

        logger.info("section_rerun_started", run_id=run_id, record_id=record_id)
        return run_id

    # ------------------------------------------------------------------
    # SSE subscription
    # ------------------------------------------------------------------

    async def subscribe_run(
        self,
        record_id: str,
        run_id: str,
        user_id: str,
    ) -> AsyncIterator[ResearchChunk]:
        """Yield a run's events from its Redis Stream until it finishes.

        Validates run_id → record_id → user_id ownership (C6).
        Phase 1: reads from stream head, no cursor resume.

        Stops on a ``done`` chunk. It also stops when the run row is gone, or
        when the run reports ``completed`` / ``failed`` and one further read
        finds no new events.

        Raises:
            ValueError: run not found or not owned by the user.
        """
        run = await self._run_store.get_run_by_record(run_id, record_id, user_id)
        if run is None:
            raise ValueError(f"Run {run_id} not found or not owned by user")

        stream_key = f"{_STREAM_KEY_PREFIX}{run_id}:events"
        last_id = "0"
        # Set once the run is terminal. One more read is done so events
        # appended between an empty read and the status check are not lost.
        draining = False

        while True:
            entries = await self._redis.xread(
                {stream_key: last_id}, count=10, block=2000
            )
            if not entries:
                if draining:
                    break
                current_run = await self._run_store.get_run(run_id, user_id)
                if current_run is None:
                    # Run row vanished; no further events are expected.
                    break
                if current_run["status"] in ("completed", "failed"):
                    draining = True
                continue

            for _, messages in entries:
                for msg_id, fields in messages:
                    last_id = msg_id
                    # C10: decode_responses=True → fields keys are strings
                    raw_data = fields.get("data", "{}")
                    try:
                        chunk = ResearchChunk.model_validate_json(raw_data)
                    except ValueError:
                        # One malformed event should not kill the SSE stream.
                        logger.warning(
                            "run_event_invalid", run_id=run_id, msg_id=msg_id,
                        )
                        continue
                    yield chunk
                    if chunk.type == "done":
                        return
