"""Notebook service — orchestrates sources, chat, and output generation.

Notebook 业务编排层。负责来源管理（搜索/上传/粘贴）、范围隔离对话、输出文档生成。
"""

from __future__ import annotations

import json
import re
import uuid
from typing import Any, AsyncIterator

from app.api.schemas.notebook import NotebookConfig
from app.api.schemas.research import ResearchChunk
from app.config import settings
from app.core.embedding import EmbeddingService
from app.core.graph_query_service import GraphQueryService
from app.core.permission import PermissionContext
from app.core.research_engine import ResearchEngine
from app.infrastructure.es_client import ESClient
from app.infrastructure.llm_client import LLMClient
from app.infrastructure.notebook_store import NotebookStore
from app.infrastructure.session_store import ResearchSessionStore
from app.prompts.notebook_prompts import (
    OUTPUT_FORMAT_TEMPLATES,
    OUTPUT_GENERATION_USER,
    OUTPUT_TYPE_NAMES,
    SOURCE_SUMMARY_SYSTEM,
    SOURCE_SUMMARY_USER,
    SUGGEST_QUESTIONS_SYSTEM,
    SUGGEST_QUESTIONS_USER,
    build_notebook_chat_system,
)
from app.utils.logger import get_logger

logger = get_logger(__name__)

_MAX_SOURCES = 10


def _gen_id(prefix: str = "") -> str:
    return f"{prefix}{uuid.uuid4().hex[:16]}"


class NotebookService:
    """Business logic layer for the Notebook module.

    Orchestrates the three Notebook concerns:

    * **Sources** — add by global search, file upload, or pasted text
      (capped at ``_MAX_SOURCES`` per notebook) plus best-effort cleanup
      of the per-notebook ES documents when sources are removed.
    * **Chat** — source-scoped Q&A delegated to :class:`ResearchEngine`,
      with message persistence and LLM-generated follow-up suggestions.
    * **Outputs** — streamed LLM generation of documents built from the
      selected sources' chunk content stored in ES.
    """

    def __init__(
        self,
        store: NotebookStore,
        es_client: ESClient,
        embedding_service: EmbeddingService,
        graph_service: GraphQueryService,
        llm_client: LLMClient,
        session_store: ResearchSessionStore | None = None,
    ) -> None:
        self._store = store
        self._es = es_client
        self._embedding = embedding_service
        self._graph = graph_service
        self._llm = llm_client
        # Optional; only forwarded to ResearchEngine for session persistence.
        self._session_store = session_store

    # ==================================================================
    # Notebook CRUD
    # ==================================================================

    async def create_notebook(self, user_id: str, title: str, description: str = "") -> dict[str, Any]:
        """Create a notebook for *user_id* and return ``{"id", "title"}``."""
        notebook_id = _gen_id("nb_")
        await self._store.create_notebook(notebook_id, user_id, title, description)
        return {"id": notebook_id, "title": title}

    async def get_notebook(self, notebook_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the notebook record with its sources attached, or None.

        The store row uses ``config_json``; the API schema expects ``config``,
        so the field is renamed in place before returning.
        """
        nb = await self._store.get_notebook(notebook_id, user_id)
        if nb is None:
            return None
        # Normalize field name: config_json -> config (match schema)
        if "config_json" in nb:
            nb["config"] = nb.pop("config_json")
        sources = await self._store.list_sources(notebook_id, user_id)
        nb["sources"] = sources
        return nb

    async def list_notebooks(
        self, user_id: str, *, q: str | None = None, status: str | None = None,
        limit: int = 50, offset: int = 0,
    ) -> tuple[list[dict[str, Any]], int]:
        """List notebooks (optionally filtered by query/status) with counts.

        Returns ``(notebooks, total)``. NOTE(review): counts are fetched with
        two extra store calls per notebook (N+1); acceptable at the current
        page size of 50, revisit if the store grows a batched count API.
        """
        notebooks, total = await self._store.list_notebooks(
            user_id, q=q, status=status, limit=limit, offset=offset,
        )
        # Attach source and message counts
        for nb in notebooks:
            nb["source_count"] = await self._store.count_sources(nb["id"], user_id)
            nb["message_count"] = await self._store.count_messages(nb["id"], user_id)
        return notebooks, total

    async def update_notebook(
        self, notebook_id: str, user_id: str, **fields: Any,
    ) -> int:
        """Update title/description/config; return the affected row count.

        Only non-None whitelisted fields are forwarded; ``config`` is
        normalized to a plain dict and stored as ``config_json``.
        """
        update_fields: dict[str, Any] = {}
        if "title" in fields and fields["title"] is not None:
            update_fields["title"] = fields["title"]
        if "description" in fields and fields["description"] is not None:
            update_fields["description"] = fields["description"]
        if "config" in fields and fields["config"] is not None:
            config = fields["config"]
            if isinstance(config, NotebookConfig):
                update_fields["config_json"] = config.model_dump()
            elif isinstance(config, dict):
                update_fields["config_json"] = config
        # Nothing valid to change: skip the store round-trip instead of
        # issuing an empty UPDATE.
        if not update_fields:
            return 0
        return await self._store.update_notebook(notebook_id, user_id, **update_fields)

    async def archive_notebook(self, notebook_id: str, user_id: str) -> int:
        """Archive the notebook; returns the affected row count."""
        return await self._store.archive_notebook(notebook_id, user_id)

    async def delete_notebook(self, notebook_id: str, user_id: str) -> int:
        """Delete the notebook and best-effort clean up ES data.

        Only upload/paste sources own their ES documents (search sources
        reference shared global docs), so only those are cleaned up.
        Cleanup failures are logged but never fail the deletion.
        """
        # Get sources to cleanup ES chunks for uploaded/pasted sources
        sources = await self._store.list_sources(notebook_id, user_id)
        affected = await self._store.delete_notebook(notebook_id, user_id)
        if affected > 0:
            for src in sources:
                if src["source_type"] in ("upload", "paste") and src.get("doc_id"):
                    try:
                        await self._cleanup_es_source(src["doc_id"])
                    except Exception as exc:
                        logger.warning(
                            "notebook_es_cleanup_failed",
                            doc_id=src.get("doc_id"),
                            error=str(exc),
                        )
        return affected

    # ==================================================================
    # Sources
    # ==================================================================

    async def _ensure_source_capacity(self, notebook_id: str, user_id: str) -> None:
        """Raise ValueError when the notebook already holds _MAX_SOURCES sources."""
        count = await self._store.count_sources(notebook_id, user_id)
        if count >= _MAX_SOURCES:
            raise ValueError(f"来源数量已达上限 ({_MAX_SOURCES})")

    async def add_source_by_search(
        self, notebook_id: str, user_id: str, doc_id: str, title: str = "",
    ) -> dict[str, Any]:
        """Add an existing document from the global index as a source.

        When no title is supplied, it is looked up from the ES meta index,
        falling back to the doc_id itself if the lookup fails.
        """
        await self._ensure_source_capacity(notebook_id, user_id)

        source_id = _gen_id("src_")
        # Verify doc exists in ES
        if not title:
            try:
                doc = await self._es.raw.get(
                    index=settings.es_meta_index, id=doc_id, _source=["title"],
                )
                # The ES client may return a dict or a response object with .body.
                raw = doc if isinstance(doc, dict) else doc.body
                title = raw.get("_source", {}).get("title", doc_id)
            except Exception:
                title = doc_id

        # Search sources are already ingested globally, hence "completed".
        await self._store.create_source(
            source_id, notebook_id, user_id, "search", title,
            doc_id=doc_id, ingest_status="completed",
        )
        return {"id": source_id, "doc_id": doc_id, "ingest_status": "completed"}

    async def add_source_by_upload(
        self,
        notebook_id: str,
        user_id: str,
        file_path: str,
        original_filename: str,
    ) -> dict[str, Any]:
        """Add a file upload as a source. Returns source record; caller dispatches Celery task."""
        await self._ensure_source_capacity(notebook_id, user_id)

        source_id = _gen_id("src_")
        doc_id = f"nb_src_{source_id}"
        # Fall back to the path's basename, tolerating both / and \ separators.
        title = original_filename or file_path.rsplit("/", 1)[-1].rsplit("\\", 1)[-1]

        await self._store.create_source(
            source_id, notebook_id, user_id, "upload", title,
            doc_id=doc_id, file_path=file_path, ingest_status="pending",
        )
        return {"id": source_id, "doc_id": doc_id, "file_path": file_path, "title": title}

    async def add_source_by_paste(
        self, notebook_id: str, user_id: str, title: str, content: str,
    ) -> dict[str, Any]:
        """Add pasted text as a source. Saves text then dispatches ingest."""
        await self._ensure_source_capacity(notebook_id, user_id)

        source_id = _gen_id("src_")
        doc_id = f"nb_src_{source_id}"

        # Persist the pasted text to disk so the ingest worker can pick it up.
        from pathlib import Path
        paste_dir = Path(settings.file_storage_path) / "notebook_paste"
        paste_dir.mkdir(parents=True, exist_ok=True)
        paste_file = paste_dir / f"{source_id}.txt"
        paste_file.write_text(content, encoding="utf-8")

        await self._store.create_source(
            source_id, notebook_id, user_id, "paste", title,
            doc_id=doc_id, file_path=str(paste_file), paste_text=content,
            ingest_status="pending",
        )
        return {"id": source_id, "doc_id": doc_id, "file_path": str(paste_file), "title": title}

    async def update_source(
        self, source_id: str, notebook_id: str, user_id: str, **fields: Any,
    ) -> int:
        """Update a source's ``selected`` flag and/or ``title``; return affected rows."""
        update_fields: dict[str, Any] = {}
        if "selected" in fields and fields["selected"] is not None:
            # Stored as an integer column (0/1).
            update_fields["selected"] = int(fields["selected"])
        if "title" in fields and fields["title"] is not None:
            update_fields["title"] = fields["title"]
        return await self._store.update_source(source_id, notebook_id, user_id, **update_fields)

    async def delete_source(self, source_id: str, notebook_id: str, user_id: str) -> int:
        """Delete a source; for upload/paste sources also clean up ES (best-effort)."""
        source = await self._store.get_source(source_id, notebook_id, user_id)
        affected = await self._store.delete_source(source_id, notebook_id, user_id)
        if affected > 0 and source and source["source_type"] in ("upload", "paste") and source.get("doc_id"):
            try:
                await self._cleanup_es_source(source["doc_id"])
            except Exception as exc:
                logger.warning("source_es_cleanup_failed", doc_id=source.get("doc_id"), error=str(exc))
        return affected

    async def retry_source_ingest(
        self, source_id: str, notebook_id: str, user_id: str,
    ) -> dict[str, Any] | None:
        """Reset failed source to pending for re-ingestion. Returns source or None.

        NOTE(review): the returned dict is the pre-reset snapshot, so its
        ``ingest_status`` still reads "failed" — callers appear to only use
        it to locate the file for re-dispatch; confirm before relying on it.
        """
        source = await self._store.get_source(source_id, notebook_id, user_id)
        if source is None or source["ingest_status"] != "failed":
            return None
        await self._store.update_source(
            source_id, notebook_id, user_id,
            ingest_status="pending", ingest_error=None,
        )
        return source

    async def _cleanup_es_source(self, doc_id: str) -> None:
        """Best-effort removal of ES documents associated with a notebook source."""
        # Meta document: a 404 is expected when ingest never completed.
        try:
            await self._es.raw.delete(index=settings.es_meta_index, id=doc_id, ignore=[404])
        except Exception:
            pass
        # All chunks referencing the doc.
        try:
            await self._es.raw.delete_by_query(
                index=settings.es_chunk_index,
                body={"query": {"term": {"doc_ids": doc_id}}},
                ignore=[404],
            )
        except Exception:
            pass

    # ==================================================================
    # Chat
    # ==================================================================

    async def chat(
        self,
        notebook_id: str,
        user_id: str,
        question: str,
        session_id: str = "default",
        acl_tokens: list[str] | None = None,
    ) -> AsyncIterator[ResearchChunk]:
        """Scoped chat using only notebook sources.

        Streams ResearchChunk items: guidance text when there are no
        sources, otherwise the ResearchEngine QA stream followed by an
        optional ``suggestions`` chunk. The exchange is persisted to the
        message store after the answer completes.
        """
        notebook = await self._store.get_notebook(notebook_id, user_id)
        if notebook is None:
            yield ResearchChunk(type="error", content="Notebook 不存在")
            return

        sources = await self._store.get_selected_sources(notebook_id, user_id)
        if not sources:
            # No material to answer from: stream onboarding guidance instead
            # of invoking the engine.
            yield ResearchChunk(
                type="text",
                content=(
                    "目前笔记本中还没有相关来源材料。\n\n"
                    "本 Notebook 的核心优势在于基于您提供的资料进行深度分析和总结。"
                    "为了能为您提供最精准的见解，建议您：\n\n"
                    "1. 在左侧**「来源」**面板上传相关的文件材料\n"
                    "2. 或通过**「文档检索」**添加系统中已有的文档\n"
                    "3. 也可以直接**粘贴文字**作为来源\n\n"
                    "添加来源后，我将仅基于这些材料为您解答问题，确保回答有据可依。"
                ),
            )
            yield ResearchChunk(type="done")
            # Still persist the user message and this guidance as assistant reply
            await self._store.create_message(
                _gen_id("msg_"), notebook_id, user_id, "user", question,
                session_id=session_id,
            )
            await self._store.create_message(
                _gen_id("msg_"), notebook_id, user_id, "assistant",
                "目前笔记本中还没有相关来源材料。请在左侧「来源」面板添加文件或文档后再提问。",
                session_id=session_id,
            )
            return

        # Split sources by type
        search_doc_ids: list[str] = []
        nb_acl_token = f"NB_{notebook_id}"
        has_ingested_sources = False

        for src in sources:
            if src["source_type"] == "search" and src.get("doc_id"):
                search_doc_ids.append(src["doc_id"])
            elif src["source_type"] in ("upload", "paste") and src.get("doc_id"):
                has_ingested_sources = True
                search_doc_ids.append(src["doc_id"])

        # Build PermissionContext with NB_ token + user's tokens. The NB_
        # token grants access to documents ingested privately for this
        # notebook, and is only needed when such sources exist.
        user_tokens = list(acl_tokens or [])
        if has_ingested_sources:
            user_tokens.append(nb_acl_token)
        perm = PermissionContext(user_id=user_id, acl_tokens=user_tokens)

        # The store returns the raw column name (config_json); "config" is
        # checked first in case a pre-normalized record is ever passed in.
        config = notebook.get("config") or notebook.get("config_json") or {}
        engine = ResearchEngine(
            es_client=self._es,
            embedding_service=self._embedding,
            graph_service=self._graph,
            llm_client=self._llm,
            session_store=self._session_store,
        )

        # Use ResearchEngine.qa() with notebook_mode; session id is namespaced
        # per notebook so notebooks never share QA history.
        qa_session_id = f"nb_{notebook_id}_{session_id}"
        gen = engine.qa(
            question,
            perm,
            session_id=qa_session_id,
            seed_doc_ids=search_doc_ids,
            notebook_mode=True,
            notebook_config=config,
        )

        # Collect answer text and references for persistence while relaying
        # every chunk to the caller unchanged.
        answer_parts: list[str] = []
        refs: list[dict[str, Any]] = []

        async for chunk in gen:
            if chunk.type == "text":
                answer_parts.append(chunk.content or "")
            elif chunk.type == "reference":
                refs.append(chunk.model_dump(exclude_none=True))
            yield chunk

        # Persist messages only when the engine actually produced an answer.
        full_answer = "".join(answer_parts)
        if full_answer:
            await self._store.create_message(
                _gen_id("msg_"), notebook_id, user_id, "user", question,
                session_id=session_id,
            )
            # Generate suggestions
            suggestions = await self._generate_suggestions_from_context(sources, question, full_answer)
            await self._store.create_message(
                _gen_id("msg_"), notebook_id, user_id, "assistant", full_answer,
                session_id=session_id,
                references_json=refs if refs else None,
                suggestions_json=suggestions if suggestions else None,
            )
            # Yield suggestions as final chunk
            if suggestions:
                yield ResearchChunk(
                    type="suggestions",
                    payload={"suggestions": suggestions},
                )

    async def get_chat_history(
        self, notebook_id: str, user_id: str,
        session_id: str = "default", limit: int = 100, offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return persisted chat messages for one notebook session (paged)."""
        return await self._store.list_messages(
            notebook_id, user_id, session_id=session_id, limit=limit, offset=offset,
        )

    async def generate_suggestions(
        self, notebook_id: str, user_id: str,
    ) -> list[str]:
        """Generate suggested questions based on sources."""
        sources = await self._store.get_selected_sources(notebook_id, user_id)
        if not sources:
            return []
        return await self._generate_suggestions_from_context(sources)

    async def _generate_suggestions_from_context(
        self,
        sources: list[dict[str, Any]],
        question: str = "",
        answer: str = "",
    ) -> list[str]:
        """Generate up to five follow-up questions using the LLM.

        Builds a prompt from source summaries (or titles) plus an optional
        truncated snapshot of the latest exchange. Returns [] on any LLM
        failure — suggestions are strictly best-effort.
        """
        summaries = []
        for src in sources[:10]:
            summary = src.get("summary") or src.get("title", "")
            if summary:
                summaries.append(f"- {summary}")
        source_text = "\n".join(summaries) if summaries else "（无来源摘要）"

        history_section = ""
        if question and answer:
            # Truncate to keep the prompt small; suggestions need only the gist.
            history_section = f"## 最近对话\n用户: {question[:200]}\n助手: {answer[:300]}\n"

        messages = [
            {"role": "system", "content": SUGGEST_QUESTIONS_SYSTEM},
            {"role": "user", "content": SUGGEST_QUESTIONS_USER.format(
                source_summaries=source_text,
                history_section=history_section,
            )},
        ]

        try:
            result_parts: list[str] = []
            async for token in self._llm.chat(messages, temperature=0.7, max_tokens=500):
                result_parts.append(token)
            text = "".join(result_parts).strip()
            # Remove only an actual leading list marker ("1. ", "3、", "- ").
            # The previous lstrip("0123456789.-、） ") stripped a character
            # set and therefore also ate legitimate leading digits (e.g. a
            # question starting with "2023…") and could leave empty strings.
            suggestions: list[str] = []
            for line in text.split("\n"):
                line = line.strip()
                if not line:
                    continue
                cleaned = re.sub(r"^(?:\d{1,3}\s*[.．、)）:：\-]\s*|[-•*]\s+)", "", line).strip()
                if cleaned:
                    suggestions.append(cleaned)
            return suggestions[:5]
        except Exception as exc:
            logger.warning("suggest_questions_failed", error=str(exc))
            return []

    # ==================================================================
    # Output generation
    # ==================================================================

    async def generate_output(
        self,
        notebook_id: str,
        user_id: str,
        output_type: str,
        title: str = "",
        custom_instructions: str = "",
    ) -> AsyncIterator[ResearchChunk]:
        """Generate an output document from notebook sources.

        Streams thinking/text chunks while the LLM writes, persists the
        output record (completed or failed), and ends with a ``done`` chunk
        carrying the output id, or an ``error`` chunk on failure.
        """
        sources = await self._store.get_selected_sources(notebook_id, user_id)
        if not sources:
            yield ResearchChunk(type="error", content="没有可用的来源。")
            return

        # Build title
        type_name = OUTPUT_TYPE_NAMES.get(output_type, "文档")
        if not title:
            title = f"{type_name} - {sources[0].get('title', 'Notebook')}"

        # Create output record with source count
        output_id = _gen_id("out_")
        source_count = len(sources)
        await self._store.create_output(
            output_id, notebook_id, user_id, output_type, title,
            context_json={"source_count": source_count, "source_titles": [s.get("title", "") for s in sources]},
        )

        yield ResearchChunk(type="thinking", content=f"正在准备生成{type_name}…")

        try:
            # Gather source content from ES
            context_text = await self._gather_source_content(sources, user_id, notebook_id)

            if not context_text:
                await self._store.update_output(
                    output_id, notebook_id, user_id,
                    status="failed", error="未能获取来源内容",
                )
                yield ResearchChunk(type="error", content="未能获取来源内容。")
                return

            # Build messages; unknown output types fall back to the "custom"
            # template.
            system_prompt = OUTPUT_FORMAT_TEMPLATES.get(output_type, OUTPUT_FORMAT_TEMPLATES["custom"])
            custom_section = f"## 用户要求\n{custom_instructions}\n" if custom_instructions else ""
            user_prompt = OUTPUT_GENERATION_USER.format(
                context_text=context_text, custom_section=custom_section,
            )

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]

            yield ResearchChunk(type="thinking", content=f"正在生成{type_name}…")

            # Stream LLM output
            content_parts: list[str] = []
            async for token in self._llm.chat(messages, temperature=0.4, max_tokens=4000):
                content_parts.append(token)
                yield ResearchChunk(type="text", content=token)

            content_md = "".join(content_parts)
            await self._store.update_output(
                output_id, notebook_id, user_id,
                content_md=content_md, status="completed",
            )
            yield ResearchChunk(
                type="done",
                payload={"output_id": output_id, "title": title},
            )

        except Exception as exc:
            logger.error("output_generation_failed", error=str(exc), output_id=output_id)
            await self._store.update_output(
                output_id, notebook_id, user_id,
                status="failed", error=str(exc),
            )
            yield ResearchChunk(type="error", content=f"生成失败: {exc}")

    async def _gather_source_content(
        self, sources: list[dict[str, Any]], user_id: str, notebook_id: str,
    ) -> str:
        """Fetch chunk content for all notebook sources from ES.

        Returns a single string with per-source sections separated by a
        horizontal rule; sources whose fetch fails are skipped with a
        warning. Caps at 50 chunks per source, in chunk_index order.
        """
        doc_ids = [s["doc_id"] for s in sources if s.get("doc_id")]
        if not doc_ids:
            return ""

        parts: list[str] = []
        for idx, doc_id in enumerate(doc_ids, 1):
            try:
                resp = await self._es.raw.search(
                    index=settings.es_chunk_index,
                    body={
                        "query": {"term": {"doc_ids": doc_id}},
                        "size": 50,
                        "sort": [{"chunk_index": {"order": "asc", "unmapped_type": "integer"}}],
                        "_source": ["content", "heading_hierarchy", "page_number"],
                    },
                )
                raw = resp if isinstance(resp, dict) else resp.body
                hits = raw.get("hits", {}).get("hits", [])
                doc_title = next(
                    (s["title"] for s in sources if s.get("doc_id") == doc_id), doc_id,
                )
                chunks_text = "\n".join(
                    h["_source"].get("content", "") for h in hits if h.get("_source", {}).get("content")
                )
                if chunks_text:
                    parts.append(f"[{idx}] 《{doc_title}》\n{chunks_text}")
            except Exception as exc:
                logger.warning("gather_source_content_failed", doc_id=doc_id, error=str(exc))

        return "\n\n---\n\n".join(parts)

    # ==================================================================
    # Source summary generation
    # ==================================================================

    async def generate_source_summary(
        self, source_id: str, notebook_id: str, user_id: str,
    ) -> str | None:
        """Generate summary for a source after ingest. Returns summary text.

        Builds a ~2000-char preview from the first five chunks (falling back
        to the title when ES is unavailable), asks the LLM for a summary,
        stores it on the source, and returns it. Returns None when there is
        nothing to summarize or the LLM call fails.
        """
        source = await self._store.get_source(source_id, notebook_id, user_id)
        if source is None or not source.get("doc_id"):
            return None

        # Fetch first few chunks
        try:
            resp = await self._es.raw.search(
                index=settings.es_chunk_index,
                body={
                    "query": {"term": {"doc_ids": source["doc_id"]}},
                    "size": 5,
                    "sort": [{"chunk_index": {"order": "asc", "unmapped_type": "integer"}}],
                    "_source": ["content"],
                },
            )
            raw = resp if isinstance(resp, dict) else resp.body
            hits = raw.get("hits", {}).get("hits", [])
            # 500 chars per chunk, 2000 overall — keeps the prompt bounded.
            preview = "\n".join(
                h["_source"].get("content", "")[:500] for h in hits if h.get("_source", {}).get("content")
            )[:2000]
        except Exception:
            preview = source.get("title", "")

        if not preview:
            return None

        messages = [
            {"role": "system", "content": SOURCE_SUMMARY_SYSTEM},
            {"role": "user", "content": SOURCE_SUMMARY_USER.format(
                title=source.get("title", ""), content_preview=preview,
            )},
        ]

        try:
            parts: list[str] = []
            async for token in self._llm.chat(messages, temperature=0.3, max_tokens=200):
                parts.append(token)
            summary = "".join(parts).strip()
            await self._store.update_source(
                source_id, notebook_id, user_id, summary=summary,
            )
            return summary
        except Exception as exc:
            logger.warning("source_summary_failed", error=str(exc))
            return None
