"""MySQL-backed repository for notebooks, sources, messages and outputs.

Notebook 持久化仓储。提供笔记本 CRUD、来源管理、消息存储、输出文档存储。
所有读取方法均包含 user_id 条件以保证数据隔离。
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from typing import Any

from app.infrastructure.mysql_client import MySQLClient
from app.utils.logger import get_logger

logger = get_logger(__name__)


def _utcnow() -> datetime:
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc


def _serialize(value: Any) -> str | None:
    """Encode *value* as JSON text for storage in a ``*_json`` column.

    ``None`` is passed through unchanged (stored as SQL NULL). Non-ASCII
    characters are written verbatim rather than being escaped.
    """
    return None if value is None else json.dumps(value, ensure_ascii=False)


def _deserialize_json(raw: Any) -> Any:
    """Decode a ``*_json`` column value read back from MySQL.

    A ``dict``/``list`` (value already decoded upstream) is returned as-is;
    ``str``/``bytes``/``bytearray`` are parsed with :func:`json.loads`.
    Undecodable input yields ``None`` instead of raising, so one corrupt row
    does not break an entire listing.

    Args:
        raw: The column value as returned by the database client.

    Returns:
        The decoded Python object, or ``None`` when ``raw`` is ``None``, of an
        unsupported type, or not valid JSON.
    """
    if raw is None:
        return None
    if isinstance(raw, (dict, list)):
        return raw
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError *and* UnicodeDecodeError,
            # which json.loads raises for bytes that are not valid UTF-8/16/32.
            return None
    return None


# ---------------------------------------------------------------------------
# Column definitions
# ---------------------------------------------------------------------------

# Notebook columns without the config payload (used for list views).
_NB_SUMMARY_COLUMNS = ", ".join((
    "id", "user_id", "title", "description", "status",
    "created_at", "updated_at",
))

# Notebook columns including ``config_json`` (used for single-notebook reads).
_NB_DETAIL_COLUMNS = ", ".join((
    "id", "user_id", "title", "description", "status", "config_json",
    "created_at", "updated_at",
))

# Source columns returned by reads; note ``paste_text`` is not among them.
_SOURCE_COLUMNS = ", ".join((
    "id", "notebook_id", "user_id", "source_type", "title", "doc_id",
    "content_hash", "file_path", "ingest_status", "ingest_task_id",
    "ingest_error", "metadata_json", "summary", "selected", "sort_order",
    "created_at", "updated_at",
))

# Chat message columns.
_MESSAGE_COLUMNS = ", ".join((
    "id", "notebook_id", "user_id", "session_id", "role", "content",
    "references_json", "graph_data_json", "suggestions_json", "created_at",
))

# Output columns without the (potentially large) ``content_md`` body.
_OUTPUT_COLUMNS = ", ".join((
    "id", "notebook_id", "user_id", "output_type", "title", "status",
    "error", "task_id", "context_json", "created_at", "updated_at",
))

# Output columns including ``content_md``.
_OUTPUT_DETAIL_COLUMNS = ", ".join((
    "id", "notebook_id", "user_id", "output_type", "title", "content_md",
    "status", "error", "task_id", "context_json", "created_at", "updated_at",
))


class NotebookStore:
    """MySQL-backed store for the notebook tables.

    Wraps ``notebooks``, ``notebook_sources``, ``notebook_messages`` and
    ``notebook_outputs``. Every query that reads or modifies a user's data
    filters on ``user_id``; the only exception is the child-row cleanup in
    :meth:`delete_notebook`, which runs only after ownership of the notebook
    has been confirmed. ``*_json`` columns are JSON-encoded on write and
    decoded on read.
    """

    # Column names in the whitelists below are interpolated directly into SQL
    # text (ORDER BY / SET), so only names listed here are ever let through.

    # Columns ``list_notebooks`` may order by.
    _SORT_WHITELIST = {"updated_at", "created_at"}
    # Columns ``update_notebook`` may change.
    _NB_UPDATE_WHITELIST = {"title", "description", "status", "config_json"}
    # Columns ``update_source`` may change.
    _SRC_UPDATE_WHITELIST = {
        "title", "doc_id", "content_hash", "ingest_status", "ingest_task_id",
        "ingest_error", "metadata_json", "summary", "selected", "sort_order",
    }
    # Columns ``update_output`` may change.
    _OUT_UPDATE_WHITELIST = {"title", "content_md", "status", "error", "task_id", "context_json"}

    def __init__(self, mysql: MySQLClient) -> None:
        """Keep the MySQL client that every query in this store goes through."""
        self._db = mysql

    # ==================================================================
    # Internal helpers
    # ==================================================================

    @staticmethod
    def _decode_json_fields(row: dict[str, Any], *columns: str) -> dict[str, Any]:
        """Decode each of *columns* in *row* in place; return the same *row*."""
        for col in columns:
            row[col] = _deserialize_json(row.get(col))
        return row

    @staticmethod
    def _build_set_clause(
        fields: dict[str, Any], whitelist: set[str],
    ) -> tuple[str, list[Any]]:
        """Build the ``col = %s, ...`` fragment of an UPDATE from *fields*.

        Keys not in *whitelist* are silently dropped (they would otherwise be
        interpolated into SQL). Values of ``*_json`` columns are JSON-encoded.

        Returns:
            ``(set_clause, params)``; ``set_clause`` is ``""`` when no key
            survived the whitelist.
        """
        sets: list[str] = []
        params: list[Any] = []
        for col, val in fields.items():
            if col not in whitelist:
                continue
            sets.append(f"{col} = %s")
            params.append(_serialize(val) if col.endswith("_json") else val)
        return ", ".join(sets), params

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in *term* for use with ``ESCAPE '!'``.

        ``!`` itself is escaped first so the later replacements are not
        double-escaped. A non-backslash escape char is used so the result does
        not depend on MySQL's ``NO_BACKSLASH_ESCAPES`` SQL mode.
        """
        return term.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    # ==================================================================
    # Notebooks CRUD
    # ==================================================================

    async def create_notebook(
        self,
        notebook_id: str,
        user_id: str,
        title: str,
        description: str = "",
        config_json: dict | None = None,
    ) -> None:
        """Insert a notebook row.

        ``config_json`` is stored JSON-encoded (``None`` is stored as NULL).
        Columns not listed in the INSERT (``status``, timestamps) are left to
        the table definition — presumably defaulted there; verify against the
        schema.
        """
        await self._db.execute(
            """
            INSERT INTO notebooks (id, user_id, title, description, config_json)
            VALUES (%s, %s, %s, %s, %s)
            """,
            (notebook_id, user_id, title, description, _serialize(config_json)),
        )

    async def get_notebook(self, notebook_id: str, user_id: str) -> dict[str, Any] | None:
        """Return one notebook (detail columns, ``config_json`` decoded).

        Returns ``None`` when the notebook does not exist or belongs to a
        different user.
        """
        row = await self._db.fetch_one(
            f"SELECT {_NB_DETAIL_COLUMNS} FROM notebooks WHERE id = %s AND user_id = %s",
            (notebook_id, user_id),
        )
        if row is None:
            return None
        return self._decode_json_fields(row, "config_json")

    async def list_notebooks(
        self,
        user_id: str,
        *,
        q: str | None = None,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
        sort_by: str = "updated_at",
    ) -> tuple[list[dict[str, Any]], int]:
        """List one page of a user's notebooks together with the total count.

        Args:
            user_id: Owner whose notebooks are listed.
            q: Optional substring matched against ``title``. ``%`` and ``_``
                in *q* are matched literally, not as wildcards.
            status: Exact status filter. When falsy, only ``'active'``
                notebooks are returned.
            limit: Maximum number of rows in the page.
            offset: Number of matching rows to skip.
            sort_by: ``"updated_at"`` (default) or ``"created_at"``; any other
                value falls back to ``"updated_at"``. Order is always
                descending.

        Returns:
            ``(rows, total)``: the page of rows (summary columns, no
            ``config_json``) and the number of notebooks matching the filters.
        """
        wheres = ["user_id = %s"]
        params: list[Any] = [user_id]
        if q:
            # Escape wildcards so e.g. a search for "50%" doesn't match everything.
            wheres.append("title LIKE %s ESCAPE '!'")
            params.append(f"%{self._escape_like(q)}%")
        if status:
            wheres.append("status = %s")
            params.append(status)
        else:
            wheres.append("status = 'active'")
        where_clause = " AND ".join(wheres)
        # The column name is interpolated into SQL, so only whitelisted names pass.
        order_col = sort_by if sort_by in self._SORT_WHITELIST else "updated_at"

        count_row = await self._db.fetch_one(
            f"SELECT COUNT(*) AS total FROM notebooks WHERE {where_clause}",
            tuple(params),
        )
        total = int(count_row["total"]) if count_row else 0

        rows = await self._db.fetch_all(
            f"SELECT {_NB_SUMMARY_COLUMNS} FROM notebooks "
            f"WHERE {where_clause} ORDER BY {order_col} DESC LIMIT %s OFFSET %s",
            (*params, limit, offset),
        )
        return rows, total

    async def update_notebook(self, notebook_id: str, user_id: str, **fields: Any) -> int:
        """Update whitelisted columns of a user's notebook.

        Keys of *fields* outside ``_NB_UPDATE_WHITELIST`` are ignored;
        ``config_json`` is JSON-encoded.

        Returns:
            0 when nothing updatable was supplied; otherwise the value returned
            by ``MySQLClient.execute`` (treated elsewhere in this class as the
            affected-row count).
        """
        set_clause, params = self._build_set_clause(fields, self._NB_UPDATE_WHITELIST)
        if not set_clause:
            return 0
        return await self._db.execute(
            f"UPDATE notebooks SET {set_clause} WHERE id = %s AND user_id = %s",
            (*params, notebook_id, user_id),
        )

    async def archive_notebook(self, notebook_id: str, user_id: str) -> int:
        """Soft-delete a notebook by setting its status to ``'archived'``."""
        return await self._db.execute(
            "UPDATE notebooks SET status = 'archived' WHERE id = %s AND user_id = %s",
            (notebook_id, user_id),
        )

    async def delete_notebook(self, notebook_id: str, user_id: str) -> int:
        """Hard-delete a notebook together with its sources, messages and outputs.

        Ownership is checked first, then child rows are removed before the
        notebook row. If a statement fails part-way, the notebook still exists
        and the delete can simply be retried. Deleting the parent first would
        leave orphaned children that a retry could no longer reach, because the
        parent delete would then affect 0 rows.

        Returns:
            Rows affected by the ``notebooks`` delete: 0 when the notebook does
            not exist or belongs to another user.
        """
        owned = await self._db.fetch_one(
            "SELECT id FROM notebooks WHERE id = %s AND user_id = %s",
            (notebook_id, user_id),
        )
        if owned is None:
            return 0
        # NOTE(review): these statements are not wrapped in a transaction; if
        # MySQLClient exposes one, use it to make the whole delete atomic.
        # Table names come from this fixed tuple, never from caller input.
        for table in ("notebook_sources", "notebook_messages", "notebook_outputs"):
            await self._db.execute(
                f"DELETE FROM {table} WHERE notebook_id = %s", (notebook_id,)
            )
        return await self._db.execute(
            "DELETE FROM notebooks WHERE id = %s AND user_id = %s",
            (notebook_id, user_id),
        )

    # ==================================================================
    # Sources CRUD
    # ==================================================================

    async def create_source(
        self,
        source_id: str,
        notebook_id: str,
        user_id: str,
        source_type: str,
        title: str,
        *,
        doc_id: str | None = None,
        content_hash: str | None = None,
        file_path: str | None = None,
        paste_text: str | None = None,
        ingest_status: str = "pending",
        ingest_task_id: str | None = None,
        metadata_json: dict | None = None,
        sort_order: int = 0,
    ) -> None:
        """Insert a source row into a notebook.

        ``metadata_json`` is JSON-encoded. ``paste_text`` is written here but is
        not part of ``_SOURCE_COLUMNS``, so the read methods do not return it.
        """
        await self._db.execute(
            """
            INSERT INTO notebook_sources
                (id, notebook_id, user_id, source_type, title, doc_id,
                 content_hash, file_path, paste_text, ingest_status,
                 ingest_task_id, metadata_json, sort_order)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                source_id, notebook_id, user_id, source_type, title, doc_id,
                content_hash, file_path, paste_text, ingest_status,
                ingest_task_id, _serialize(metadata_json), sort_order,
            ),
        )

    async def get_source(
        self, source_id: str, notebook_id: str, user_id: str,
    ) -> dict[str, Any] | None:
        """Return one source (``metadata_json`` decoded), or ``None`` if not found."""
        row = await self._db.fetch_one(
            f"SELECT {_SOURCE_COLUMNS} FROM notebook_sources "
            f"WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (source_id, notebook_id, user_id),
        )
        if row:
            self._decode_json_fields(row, "metadata_json")
        return row

    async def list_sources(
        self, notebook_id: str, user_id: str,
    ) -> list[dict[str, Any]]:
        """Return all sources of a notebook ordered by ``sort_order``, then ``created_at``."""
        rows = await self._db.fetch_all(
            f"SELECT {_SOURCE_COLUMNS} FROM notebook_sources "
            f"WHERE notebook_id = %s AND user_id = %s ORDER BY sort_order ASC, created_at ASC",
            (notebook_id, user_id),
        )
        for row in rows:
            self._decode_json_fields(row, "metadata_json")
        return rows

    async def count_sources(self, notebook_id: str, user_id: str) -> int:
        """Return the number of sources in a notebook."""
        row = await self._db.fetch_one(
            "SELECT COUNT(*) AS cnt FROM notebook_sources WHERE notebook_id = %s AND user_id = %s",
            (notebook_id, user_id),
        )
        return int(row["cnt"]) if row else 0

    async def update_source(self, source_id: str, notebook_id: str, user_id: str, **fields: Any) -> int:
        """Update whitelisted columns of a source.

        Keys of *fields* outside ``_SRC_UPDATE_WHITELIST`` are ignored;
        ``metadata_json`` is JSON-encoded.

        Returns:
            0 when nothing updatable was supplied, otherwise the result of
            ``MySQLClient.execute``.
        """
        set_clause, params = self._build_set_clause(fields, self._SRC_UPDATE_WHITELIST)
        if not set_clause:
            return 0
        return await self._db.execute(
            f"UPDATE notebook_sources SET {set_clause} "
            f"WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (*params, source_id, notebook_id, user_id),
        )

    async def delete_source(self, source_id: str, notebook_id: str, user_id: str) -> int:
        """Delete one source row."""
        return await self._db.execute(
            "DELETE FROM notebook_sources WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (source_id, notebook_id, user_id),
        )

    async def get_selected_sources(
        self, notebook_id: str, user_id: str,
    ) -> list[dict[str, Any]]:
        """Return sources that are both selected and fully ingested.

        Only rows with ``selected = 1`` and ``ingest_status = 'completed'`` are
        included; these are the sources used as chat context.
        """
        rows = await self._db.fetch_all(
            f"SELECT {_SOURCE_COLUMNS} FROM notebook_sources "
            f"WHERE notebook_id = %s AND user_id = %s AND selected = 1 AND ingest_status = 'completed' "
            f"ORDER BY sort_order ASC, created_at ASC",
            (notebook_id, user_id),
        )
        for row in rows:
            self._decode_json_fields(row, "metadata_json")
        return rows

    # ==================================================================
    # Messages CRUD
    # ==================================================================

    async def create_message(
        self,
        message_id: str,
        notebook_id: str,
        user_id: str,
        role: str,
        content: str,
        *,
        session_id: str = "default",
        references_json: list | None = None,
        graph_data_json: dict | None = None,
        suggestions_json: list | None = None,
    ) -> None:
        """Insert a chat message; the ``*_json`` payloads are JSON-encoded."""
        await self._db.execute(
            """
            INSERT INTO notebook_messages
                (id, notebook_id, user_id, session_id, role, content,
                 references_json, graph_data_json, suggestions_json)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                message_id, notebook_id, user_id, session_id, role, content,
                _serialize(references_json), _serialize(graph_data_json),
                _serialize(suggestions_json),
            ),
        )

    async def list_messages(
        self,
        notebook_id: str,
        user_id: str,
        *,
        session_id: str = "default",
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return one page of a session's messages, oldest first, JSON fields decoded."""
        rows = await self._db.fetch_all(
            f"SELECT {_MESSAGE_COLUMNS} FROM notebook_messages "
            f"WHERE notebook_id = %s AND user_id = %s AND session_id = %s "
            f"ORDER BY created_at ASC LIMIT %s OFFSET %s",
            (notebook_id, user_id, session_id, limit, offset),
        )
        for row in rows:
            self._decode_json_fields(
                row, "references_json", "graph_data_json", "suggestions_json",
            )
        return rows

    async def count_messages(
        self,
        notebook_id: str,
        user_id: str,
        *,
        session_id: str | None = None,
    ) -> int:
        """Return the number of messages in a notebook.

        Args:
            notebook_id: Notebook whose messages are counted.
            user_id: Owner of the notebook.
            session_id: When given, count only that session's messages, which
                matches :meth:`list_messages` for pagination. ``None`` (the
                default) counts across all sessions.
        """
        sql = "SELECT COUNT(*) AS cnt FROM notebook_messages WHERE notebook_id = %s AND user_id = %s"
        params: list[Any] = [notebook_id, user_id]
        if session_id is not None:
            sql += " AND session_id = %s"
            params.append(session_id)
        row = await self._db.fetch_one(sql, tuple(params))
        return int(row["cnt"]) if row else 0

    async def delete_messages(self, notebook_id: str, user_id: str, session_id: str = "default") -> int:
        """Delete every message of one session in a notebook."""
        return await self._db.execute(
            "DELETE FROM notebook_messages WHERE notebook_id = %s AND user_id = %s AND session_id = %s",
            (notebook_id, user_id, session_id),
        )

    # ==================================================================
    # Outputs CRUD
    # ==================================================================

    async def create_output(
        self,
        output_id: str,
        notebook_id: str,
        user_id: str,
        output_type: str,
        title: str,
        *,
        task_id: str | None = None,
        context_json: dict | None = None,
    ) -> None:
        """Insert an output row; ``context_json`` is JSON-encoded.

        ``content_md``, ``status`` and ``error`` are not written here; they are
        set later through :meth:`update_output`.
        """
        await self._db.execute(
            """
            INSERT INTO notebook_outputs
                (id, notebook_id, user_id, output_type, title, task_id, context_json)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            (output_id, notebook_id, user_id, output_type, title, task_id, _serialize(context_json)),
        )

    async def get_output(
        self, output_id: str, notebook_id: str, user_id: str,
    ) -> dict[str, Any] | None:
        """Return one output including ``content_md``, or ``None`` if not found."""
        row = await self._db.fetch_one(
            f"SELECT {_OUTPUT_DETAIL_COLUMNS} FROM notebook_outputs "
            f"WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (output_id, notebook_id, user_id),
        )
        if row:
            self._decode_json_fields(row, "context_json")
        return row

    async def list_outputs(
        self, notebook_id: str, user_id: str,
    ) -> list[dict[str, Any]]:
        """Return a notebook's outputs, newest first, without ``content_md``."""
        rows = await self._db.fetch_all(
            f"SELECT {_OUTPUT_COLUMNS} FROM notebook_outputs "
            f"WHERE notebook_id = %s AND user_id = %s ORDER BY created_at DESC",
            (notebook_id, user_id),
        )
        for row in rows:
            self._decode_json_fields(row, "context_json")
        return rows

    async def update_output(self, output_id: str, notebook_id: str, user_id: str, **fields: Any) -> int:
        """Update whitelisted columns of an output.

        Keys of *fields* outside ``_OUT_UPDATE_WHITELIST`` are ignored;
        ``context_json`` is JSON-encoded.

        Returns:
            0 when nothing updatable was supplied, otherwise the result of
            ``MySQLClient.execute``.
        """
        set_clause, params = self._build_set_clause(fields, self._OUT_UPDATE_WHITELIST)
        if not set_clause:
            return 0
        return await self._db.execute(
            f"UPDATE notebook_outputs SET {set_clause} "
            f"WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (*params, output_id, notebook_id, user_id),
        )

    async def delete_output(self, output_id: str, notebook_id: str, user_id: str) -> int:
        """Delete one output row."""
        return await self._db.execute(
            "DELETE FROM notebook_outputs WHERE id = %s AND notebook_id = %s AND user_id = %s",
            (output_id, notebook_id, user_id),
        )
