"""SQLite-backed storage for async transcription tasks."""

from __future__ import annotations

import json
import sqlite3
import threading
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

# The database directory is "sql_data" inside the grandparent of this
# module's own directory (parents[0] is the directory containing this file).
_DB_DIR = Path(__file__).resolve().parents[2].joinpath("sql_data")
_DB_PATH = _DB_DIR.joinpath("asr_tasks.db")
# Process-wide re-entrant lock that every helper below holds while it talks
# to the database, so in-process access is serialized.
_DB_LOCK = threading.RLock()


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with second precision.

    The result carries an explicit ``+00:00`` offset, e.g.
    ``2024-01-01T12:34:56+00:00``.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat(timespec="seconds")


def _get_conn() -> sqlite3.Connection:
    """Open a new SQLite connection to the task database at ``_DB_PATH``.

    Rows come back as ``sqlite3.Row`` so columns can be read by name. The
    30-second ``timeout`` makes the connection wait on a locked database
    instead of failing immediately. The returned connection is not closed
    here; the caller owns it.
    """
    db_file = str(_DB_PATH)
    connection = sqlite3.connect(db_file, timeout=30)
    connection.row_factory = sqlite3.Row
    return connection


def init_task_db() -> None:
    """Create the database directory and the ``transcribe_tasks`` table if missing.

    Also switches the database to WAL journal mode. Safe to call repeatedly:
    the directory is created with ``exist_ok=True`` and the table with
    ``CREATE TABLE IF NOT EXISTS``.
    """
    _DB_DIR.mkdir(parents=True, exist_ok=True)
    with _DB_LOCK:
        # ``with conn`` only commits/rolls back the transaction; it never
        # closes the connection, so ``closing`` is needed to release the handle.
        with closing(_get_conn()) as conn, conn:
            conn.execute("PRAGMA journal_mode=WAL;")
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS transcribe_tasks (
                    task_id TEXT PRIMARY KEY,
                    filename TEXT,
                    speaker_ids TEXT,
                    status TEXT NOT NULL,
                    result_json TEXT,
                    error_message TEXT,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT,
                    updated_at TEXT NOT NULL
                )
                """
            )


def create_task_record(task_id: str, filename: str, speaker_ids: Optional[str]) -> None:
    """Insert a new task row in the ``queued`` state.

    ``created_at`` and ``updated_at`` are both set to the current UTC time;
    ``result_json``, ``error_message``, ``started_at`` and ``finished_at``
    start out NULL.

    Args:
        task_id: Primary key of the new row.
        filename: Stored verbatim in the ``filename`` column.
        speaker_ids: Stored verbatim in the ``speaker_ids`` column; may be None.

    Raises:
        sqlite3.IntegrityError: If a row with ``task_id`` already exists
            (``task_id`` is the table's primary key).
    """
    now = _utc_now_iso()
    values = (
        task_id,
        filename,
        speaker_ids,
        "queued",
        None,
        None,
        now,
        None,
        None,
        now,
    )
    with _DB_LOCK:
        # ``with conn`` commits on success / rolls back on error; ``closing``
        # makes sure the connection itself is released afterwards.
        with closing(_get_conn()) as conn, conn:
            conn.execute(
                """
                INSERT INTO transcribe_tasks (
                    task_id, filename, speaker_ids, status,
                    result_json, error_message, created_at,
                    started_at, finished_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                values,
            )


def mark_task_running(task_id: str) -> None:
    """Set the task's status to ``running`` and stamp ``started_at``/``updated_at``.

    Args:
        task_id: Row to update. If no row matches, the UPDATE affects
            nothing and no error is raised.
    """
    now = _utc_now_iso()
    with _DB_LOCK:
        # ``closing`` releases the handle; ``with conn`` handles commit/rollback.
        with closing(_get_conn()) as conn, conn:
            conn.execute(
                """
                UPDATE transcribe_tasks
                SET status = ?, started_at = ?, updated_at = ?
                WHERE task_id = ?
                """,
                ("running", now, now, task_id),
            )


def mark_task_succeeded(task_id: str, result: Dict[str, Any]) -> None:
    """Mark a task ``succeeded``, store its result as JSON and clear any error.

    Both ``finished_at`` and ``updated_at`` are set to the current UTC time.

    Args:
        task_id: Row to update. If no row matches, nothing is changed.
        result: Value serialized with ``json.dumps(..., ensure_ascii=False)``
            into ``result_json``.

    Raises:
        TypeError: If ``result`` is not JSON-serializable. This is raised
            before any database connection is opened.
    """
    # Serialize first so a bad payload never opens a connection or takes the lock.
    payload = json.dumps(result, ensure_ascii=False)
    now = _utc_now_iso()
    with _DB_LOCK:
        # ``closing`` releases the handle; ``with conn`` handles commit/rollback.
        with closing(_get_conn()) as conn, conn:
            conn.execute(
                """
                UPDATE transcribe_tasks
                SET status = ?, result_json = ?, error_message = ?,
                    finished_at = ?, updated_at = ?
                WHERE task_id = ?
                """,
                ("succeeded", payload, None, now, now, task_id),
            )


def mark_task_failed(task_id: str, error_message: str) -> None:
    """Mark a task ``failed`` and record ``error_message``.

    Both ``finished_at`` and ``updated_at`` are set to the current UTC time.
    ``result_json`` is not modified.

    Args:
        task_id: Row to update. If no row matches, nothing is changed.
        error_message: Stored verbatim in the ``error_message`` column.
    """
    now = _utc_now_iso()
    with _DB_LOCK:
        # ``closing`` releases the handle; ``with conn`` handles commit/rollback.
        with closing(_get_conn()) as conn, conn:
            conn.execute(
                """
                UPDATE transcribe_tasks
                SET status = ?, error_message = ?,
                    finished_at = ?, updated_at = ?
                WHERE task_id = ?
                """,
                ("failed", error_message, now, now, task_id),
            )


def get_task_record(task_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one task as a plain dict, or ``None`` if ``task_id`` is unknown.

    The returned dict has the keys ``task_id``, ``filename``, ``speaker_ids``,
    ``status``, ``result``, ``error_message``, ``created_at``, ``started_at``,
    ``finished_at`` and ``updated_at``. ``result`` is the JSON-decoded
    ``result_json`` column. It is ``None`` when that column is NULL/empty or
    does not contain valid JSON.
    """
    with _DB_LOCK:
        # ``closing`` releases the handle; the fetched Row stays usable after close.
        with closing(_get_conn()) as conn, conn:
            row = conn.execute(
                """
                SELECT task_id, filename, speaker_ids, status, result_json,
                       error_message, created_at, started_at, finished_at, updated_at
                FROM transcribe_tasks
                WHERE task_id = ?
                """,
                (task_id,),
            ).fetchone()

    if row is None:
        return None

    result = None
    if row["result_json"]:
        try:
            result = json.loads(row["result_json"])
        except ValueError:
            # Covers json.JSONDecodeError (and UnicodeDecodeError for bytes):
            # a corrupt payload degrades to "no result" rather than failing
            # the whole lookup.
            result = None

    return {
        "task_id": row["task_id"],
        "filename": row["filename"],
        "speaker_ids": row["speaker_ids"],
        "status": row["status"],
        "result": result,
        "error_message": row["error_message"],
        "created_at": row["created_at"],
        "started_at": row["started_at"],
        "finished_at": row["finished_at"],
        "updated_at": row["updated_at"],
    }
