"""MySQL-backed repository for research records, reports and runs.

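Persistence repository for research records. Provides record CRUD, separate
report storage, and run-instance tracking.
All read methods include a user_id condition to guarantee data isolation.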
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from typing import Any

from app.infrastructure.mysql_client import MySQLClient
from app.utils.logger import get_logger

# Logger obtained for this module's name.
logger = get_logger(__name__)

# Column projection used by the list queries (list_records, list_versions).
# It omits session_id, the *_json payload columns, notes, the parent/root
# links, last_error and completed_at, which only _DETAIL_COLUMNS selects.
_SUMMARY_COLUMNS = (
    "id, user_id, title, mode, status, output_template, summary, "
    "archived, version_no, created_at, updated_at"
)

# Column projection used by get_record. The six *_json columns listed here are
# the ones get_record runs through _deserialize_json before returning the row.
_DETAIL_COLUMNS = (
    "id, user_id, session_id, title, mode, status, output_template, summary, "
    "archived, task_json, plan_json, references_json, imported_items_json, "
    "clarification_messages_json, chat_messages_json, "
    "notes, parent_record_id, root_record_id, version_no, last_error, "
    "created_at, updated_at, completed_at"
)

# States that allow deletion (C17 alignment).
# delete_record expands this tuple into the IN (...) list of its DELETE
# statement, so a record in any other status is left untouched.
_DELETABLE_STATES = ("draft", "planned", "failed")


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)


def _serialize(value: Any) -> str | None:
    """Encode *value* as JSON text for storage in a JSON column.

    ``None`` is passed through unchanged so the column is written as SQL NULL
    rather than the JSON literal ``null``. Non-ASCII characters are kept as-is
    (``ensure_ascii=False``) instead of being ``\\uXXXX``-escaped.
    """
    return None if value is None else json.dumps(value, ensure_ascii=False)


def _deserialize_json(raw: Any) -> Any:
    """Safely deserialize a JSON column value that may already be parsed.

    Args:
        raw: The value read from a JSON column. It may be ``None``, an
            already-decoded ``dict``/``list``, or JSON text as ``str``,
            ``bytes`` or ``bytearray``.

    Returns:
        ``raw`` itself when it is ``None`` or already a ``dict``/``list``;
        the decoded value when it is JSON text; ``None`` when the text cannot
        be decoded or when ``raw`` is of any other type.
    """
    if raw is None or isinstance(raw, (dict, list)):
        return raw
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers both json.JSONDecodeError (malformed JSON) and
            # the UnicodeDecodeError json.loads raises for undecodable bytes;
            # catching only JSONDecodeError let the latter escape.
            return None
    return None


class ResearchRecordStore:
    """MySQL-backed store for research_records and research_record_reports.

    Queries and updates against ``research_records`` are filtered by
    ``user_id``. ``save_report`` and the sub-table cleanup in
    ``delete_record`` address rows by ``record_id`` only.
    """

    # Whitelist for sort columns to prevent SQL injection
    _SORT_WHITELIST = {"updated_at", "created_at"}

    # Columns of research_records that get_record decodes from JSON text.
    _JSON_COLUMNS = (
        "task_json",
        "plan_json",
        "references_json",
        "imported_items_json",
        "clarification_messages_json",
        "chat_messages_json",
    )

    def __init__(self, mysql: MySQLClient) -> None:
        self._db = mysql

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _checked_column(col: str) -> str:
        """Return *col* if it is a plain ASCII identifier, else raise.

        Column names cannot be bound as query parameters, so they are
        interpolated into the SQL text; this guard keeps anything other than
        ``[A-Za-z_][A-Za-z0-9_]*`` out of the statement.

        Raises:
            ValueError: if *col* is not a plain ASCII identifier.
        """
        if not (isinstance(col, str) and col.isascii() and col.isidentifier()):
            raise ValueError(f"invalid column name: {col!r}")
        return col

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE metacharacters in *term* using ``!`` as escape char.

        The escape character itself is doubled first so that the ``!``
        inserted in front of ``%`` and ``_`` is not escaped again.
        """
        return term.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    # ------------------------------------------------------------------
    # Records CRUD
    # ------------------------------------------------------------------

    async def create_record(
        self,
        record_id: str,
        user_id: str,
        title: str,
        mode: str = "deep",
        output_template: str = "comprehensive",
        task_json: dict | None = None,
        imported_items_json: list | None = None,
        parent_record_id: str | None = None,
        root_record_id: str | None = None,
        version_no: int = 1,
    ) -> None:
        """Insert a new record with status ``'draft'``.

        ``task_json`` and ``imported_items_json`` are JSON-encoded before
        being stored. When ``root_record_id`` is empty the record becomes its
        own root (``root_record_id = record_id``).
        """
        await self._db.execute(
            """
            INSERT INTO research_records
                (id, user_id, title, mode, status, output_template,
                 task_json, imported_items_json,
                 parent_record_id, root_record_id, version_no)
            VALUES (%s, %s, %s, %s, 'draft', %s, %s, %s, %s, %s, %s)
            """,
            (
                record_id,
                user_id,
                title,
                mode,
                output_template,
                _serialize(task_json),
                _serialize(imported_items_json),
                parent_record_id,
                root_record_id or record_id,
                version_no,
            ),
        )

    async def get_record(self, record_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the full record owned by *user_id*, or ``None`` if absent.

        The JSON columns listed in ``_JSON_COLUMNS`` are decoded in place on
        the returned row; an undecodable value becomes ``None``.
        """
        row = await self._db.fetch_one(
            f"SELECT {_DETAIL_COLUMNS} FROM research_records WHERE id = %s AND user_id = %s",
            (record_id, user_id),
        )
        if row is None:
            return None
        for col in self._JSON_COLUMNS:
            row[col] = _deserialize_json(row.get(col))
        return row

    async def list_records(
        self,
        user_id: str,
        *,
        q: str | None = None,
        status: str | None = None,
        mode: str | None = None,
        output_template: str | None = None,
        created_after: str | None = None,
        created_before: str | None = None,
        archived: bool = False,
        sort_by: str = "updated_at",
        sort_order: str = "desc",
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of summary rows plus the total match count.

        Args:
            user_id: Owner whose records are listed.
            q: Substring to look for in ``title``. ``%`` and ``_`` are matched
                literally rather than acting as wildcards.
            status, mode, output_template: Optional exact-match filters.
            created_after, created_before: Optional inclusive bounds on
                ``created_at``; passed to the database as given.
            archived: Select archived (``True``) or non-archived records.
            sort_by: ``"updated_at"`` or ``"created_at"``; anything else falls
                back to ``"updated_at"``.
            sort_order: ``"asc"`` (case-insensitive) for ascending; anything
                else sorts descending.
            limit, offset: Page window; negative values are treated as 0.

        Returns:
            ``(rows, total)`` where ``total`` counts all rows matching the
            filters, ignoring ``limit``/``offset``.
        """
        wheres = ["user_id = %s", "archived = %s"]
        params: list[Any] = [user_id, int(archived)]
        if q:
            # Explicit ESCAPE keeps the behaviour independent of the server's
            # backslash-escape SQL mode.
            wheres.append("title LIKE %s ESCAPE '!'")
            params.append(f"%{self._escape_like(q)}%")
        exact_filters = (
            ("status", status),
            ("mode", mode),
            ("output_template", output_template),
        )
        for column, value in exact_filters:
            if value:
                wheres.append(f"{column} = %s")
                params.append(value)
        if created_after:
            wheres.append("created_at >= %s")
            params.append(created_after)
        if created_before:
            wheres.append("created_at <= %s")
            params.append(created_before)
        where_clause = " AND ".join(wheres)

        # Safe sort (whitelist validated)
        col = sort_by if sort_by in self._SORT_WHITELIST else "updated_at"
        direction = "ASC" if (sort_order or "").lower() == "asc" else "DESC"

        # MySQL rejects negative LIMIT/OFFSET values, so clamp them to zero.
        page_size = max(int(limit), 0)
        skip = max(int(offset), 0)

        count_row = await self._db.fetch_one(
            f"SELECT COUNT(*) AS total FROM research_records WHERE {where_clause}",
            tuple(params),
        )
        total = int(count_row["total"]) if count_row else 0

        # `id` is a tiebreaker: without it rows sharing a timestamp have no
        # defined order and may repeat or go missing across pages.
        rows = await self._db.fetch_all(
            f"SELECT {_SUMMARY_COLUMNS} FROM research_records "
            f"WHERE {where_clause} "
            f"ORDER BY {col} {direction}, id {direction} LIMIT %s OFFSET %s",
            (*params, page_size, skip),
        )
        return rows, total

    async def update_record(self, record_id: str, user_id: str, **fields: Any) -> int:
        """Update the given columns of a record owned by *user_id*.

        Args:
            record_id: Record to update.
            user_id: Owner; the update only applies to this user's record.
            **fields: Column name -> new value. Values for columns whose name
                ends in ``_json`` are JSON-encoded first.

        Returns:
            The value returned by the database client's ``execute`` (used by
            callers in this file as an affected-row count), or 0 when no
            fields were given.

        Raises:
            ValueError: if a field name is not a plain ASCII identifier. No
                statement is executed in that case.
        """
        if not fields:
            return 0
        assignments: list[str] = []
        params: list[Any] = []
        for col, val in fields.items():
            assignments.append(f"{self._checked_column(col)} = %s")
            params.append(_serialize(val) if col.endswith("_json") else val)
        set_clause = ", ".join(assignments)
        params.extend([record_id, user_id])
        return await self._db.execute(
            f"UPDATE research_records SET {set_clause} WHERE id = %s AND user_id = %s",
            tuple(params),
        )

    async def touch_record(self, record_id: str, user_id: str) -> int:
        """Only update updated_at timestamp."""
        return await self._db.execute(
            "UPDATE research_records SET updated_at = CURRENT_TIMESTAMP WHERE id = %s AND user_id = %s",
            (record_id, user_id),
        )

    async def list_versions(
        self, root_record_id: str, user_id: str, limit: int = 50,
    ) -> list[dict[str, Any]]:
        """Return all records sharing the same root_record_id (I2: use _SUMMARY_COLUMNS).

        Rows are ordered by ``version_no`` ascending and capped at *limit*.
        """
        return await self._db.fetch_all(
            f"SELECT {_SUMMARY_COLUMNS} FROM research_records "
            f"WHERE root_record_id = %s AND user_id = %s "
            f"ORDER BY version_no ASC LIMIT %s",
            (root_record_id, user_id, limit),
        )

    async def archive_record(self, record_id: str, user_id: str) -> int:
        """Mark a record as archived unless it is running or pending.

        Returns the result of the UPDATE; 0 is expected when the record does
        not exist, belongs to another user, or is in a blocked status.
        """
        # M5: block archiving running/pending records.
        # NOTE(review): 'pending' was missing from the filter although the
        # rule names it; confirm whether records can carry that status.
        return await self._db.execute(
            "UPDATE research_records SET archived = 1 "
            "WHERE id = %s AND user_id = %s AND status NOT IN ('running', 'pending')",
            (record_id, user_id),
        )

    async def delete_record(self, record_id: str, user_id: str) -> int:
        """Delete a record and its sub-table rows (H5).

        Only allowed for draft/planned/failed states. Returns the result of
        the DELETE on ``research_records``; the sub-table rows are removed
        only when that result is greater than zero.
        """
        placeholders = ", ".join(["%s"] * len(_DELETABLE_STATES))
        affected = await self._db.execute(
            f"DELETE FROM research_records WHERE id = %s AND user_id = %s AND status IN ({placeholders})",
            (record_id, user_id, *_DELETABLE_STATES),
        )
        if affected > 0:
            # Cascade-delete orphan sub-table rows.
            # NOTE(review): the three DELETEs are issued as separate calls;
            # whether they share a transaction depends on MySQLClient.
            for child_table in ("research_record_reports", "research_record_runs"):
                await self._db.execute(
                    f"DELETE FROM {child_table} WHERE record_id = %s",
                    (record_id,),
                )
        return affected

    # ------------------------------------------------------------------
    # Report sub-table
    # ------------------------------------------------------------------

    async def get_report(self, record_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the stored report for a record owned by *user_id*.

        Ownership is enforced by joining to ``research_records``. The
        ``report_json`` column is decoded on the returned row. Returns
        ``None`` when no matching report exists.
        """
        row = await self._db.fetch_one(
            """
            SELECT r.record_id, r.report_json, r.final_document_md, r.updated_at
            FROM research_record_reports r
            JOIN research_records rec ON rec.id = r.record_id
            WHERE r.record_id = %s AND rec.user_id = %s
            """,
            (record_id, user_id),
        )
        if row is None:
            return None
        row["report_json"] = _deserialize_json(row.get("report_json"))
        return row

    async def save_report(self, record_id: str, report_json: dict) -> None:
        """Insert or replace the report row for *record_id*.

        ``final_document_md`` is copied out of ``report_json`` into its own
        column (``None`` when the key is absent). This method does not check
        record ownership.
        """
        final_doc = report_json.get("final_document_md")
        await self._db.execute(
            """
            INSERT INTO research_record_reports (record_id, report_json, final_document_md)
            VALUES (%s, %s, %s)
            ON DUPLICATE KEY UPDATE report_json = VALUES(report_json),
                                    final_document_md = VALUES(final_document_md),
                                    updated_at = CURRENT_TIMESTAMP
            """,
            (record_id, _serialize(report_json), final_doc),
        )


class ResearchRecordRunStore:
    """MySQL-backed store for research_record_runs."""

    def __init__(self, mysql: MySQLClient) -> None:
        self._db = mysql

    @staticmethod
    def _checked_column(col: str) -> str:
        """Return *col* if it is a plain ASCII identifier, else raise.

        Column names are interpolated into the SQL text because they cannot
        be bound as parameters; this keeps anything other than
        ``[A-Za-z_][A-Za-z0-9_]*`` out of the statement.

        Raises:
            ValueError: if *col* is not a plain ASCII identifier.
        """
        if not (isinstance(col, str) and col.isascii() and col.isidentifier()):
            raise ValueError(f"invalid column name: {col!r}")
        return col

    async def create_run(
        self,
        run_id: str,
        record_id: str,
        user_id: str,
        run_type: str,
        section_title: str | None = None,
    ) -> None:
        """Insert a new run row with status ``'pending'``."""
        await self._db.execute(
            """
            INSERT INTO research_record_runs
                (run_id, record_id, user_id, run_type, status, section_title)
            VALUES (%s, %s, %s, %s, 'pending', %s)
            """,
            (run_id, record_id, user_id, run_type, section_title),
        )

    async def get_run(self, run_id: str, user_id: str) -> dict[str, Any] | None:
        """Fetch a run by id, restricted to rows owned by *user_id*."""
        return await self._db.fetch_one(
            "SELECT * FROM research_record_runs WHERE run_id = %s AND user_id = %s",
            (run_id, user_id),
        )

    async def get_run_by_record(
        self, run_id: str, record_id: str, user_id: str
    ) -> dict[str, Any] | None:
        """Fetch run with full three-level ownership check."""
        return await self._db.fetch_one(
            "SELECT * FROM research_record_runs "
            "WHERE run_id = %s AND record_id = %s AND user_id = %s",
            (run_id, record_id, user_id),
        )

    async def get_latest_run(
        self, record_id: str, user_id: str
    ) -> dict[str, Any] | None:
        """Fetch the most recently created run of a record owned by *user_id*.

        ``run_id`` breaks ties between runs with the same ``created_at``.
        """
        return await self._db.fetch_one(
            "SELECT * FROM research_record_runs "
            "WHERE record_id = %s AND user_id = %s "
            "ORDER BY created_at DESC, run_id DESC LIMIT 1",
            (record_id, user_id),
        )

    async def update_run(self, run_id: str, **fields: Any) -> int:
        """Update the given columns of a run.

        Args:
            run_id: Run to update.
            **fields: Column name -> new value, written as given.

        Returns:
            The value returned by the database client's ``execute``, or 0
            when no fields were given.

        Raises:
            ValueError: if a field name is not a plain ASCII identifier. No
                statement is executed in that case.
        """
        # NOTE(review): unlike the getters, this statement is keyed on run_id
        # alone and does not filter by user_id.
        if not fields:
            return 0
        assignments = [f"{self._checked_column(col)} = %s" for col in fields]
        params = [*fields.values(), run_id]
        return await self._db.execute(
            f"UPDATE research_record_runs SET {', '.join(assignments)} WHERE run_id = %s",
            tuple(params),
        )
