"""Per-site task queue for crawl jobs — prevents concurrent hits on a site.

Design:
  • A submit() call returns immediately with a job_id. The job is appended
    to the per-site/rate-group FIFO; one worker task drains each FIFO.
  • Concurrency across unrelated sites is unlimited; within one rate group
    it's strictly 1.
  • State is in-process memory. Lost on container restart — that's fine for
    v1 (jobs are idempotent; scheduler will re-trigger them next tick).
  • cancel() removes a queued job or sets the "stop requested" flag on a
    running one. Hard-kill of the sync crawl_target is intentionally not
    supported (would need a subprocess).

Call sites:
  • admin/targets.run_target  → submit()
  • scheduler cron fn         → submit()  (via HTTP to api; next step)
  • admin/jobs.* endpoints     → list() + cancel()
"""
from __future__ import annotations

import asyncio
import logging
import time
import traceback
import uuid
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional
from urllib.parse import urlparse

log = logging.getLogger(__name__)


# Cache: site_code → queue key. Multiple site_codes can share one key so we
# never hit the same host/security cluster concurrently.
#
# Most sites use their base_url host as the key. Guangdong provincial gkmlpt
# sites are stricter: many different *.gd.gov.cn hosts appear to sit behind
# the same gateway/WAF policy, so they intentionally share one FIFO.
# Cache is invalidated lazily: a site row's base_url rarely changes; if it
# does, an api restart re-primes the cache.
_HOST_CACHE: dict[str, str] = {}


def _rate_key_for_site(site_code: str) -> str:
    """Resolve site_code → rate queue key. Falls back
    to site_code itself if the lookup fails (unknown site, DB hiccup), so
    queue serialization degrades gracefully to per-site rather than
    silently colliding everything into one queue."""
    cached = _HOST_CACHE.get(site_code)
    if cached is not None:
        return cached
    try:
        from govcrawler.db import get_sessionmaker
        from govcrawler.models import CrawlSite
        S = get_sessionmaker()
        with S() as s:
            row = s.query(CrawlSite).filter_by(site_code=site_code).first()
            if row and row.base_url:
                host = urlparse(row.base_url).netloc.lower() or site_code
                if (
                    row.cms_adapter == "gkmlpt"
                    and (host == "gd.gov.cn" or host.endswith(".gd.gov.cn"))
                ):
                    key = "gd_gkmlpt_shared"
                else:
                    key = host
                _HOST_CACHE[site_code] = key
                return key
    except Exception:
        pass
    _HOST_CACHE[site_code] = site_code
    return site_code


def _host_for_site(site_code: str) -> str:
    """Backward-compatible alias used for persisted crawl_job.host.

    Existing DB column name is `host`, but the value is now a rate queue key.
    """
    return _rate_key_for_site(site_code)

JobStatus = Literal[
    "queued",       # sitting in the site's FIFO
    "running",      # worker picked it up
    "done",         # completed successfully
    "failed",       # crawl_target raised
    "cancelled",    # dequeued before running, or terminate-requested mid-run
]


def _epoch_iso(ts: float | None) -> str | None:
    if ts is None:
        return None
    from govcrawler.timeutil import epoch_iso_cn
    return epoch_iso_cn(ts)


def _epoch_db_dt(ts: float | None):
    from govcrawler.timeutil import epoch_db_dt_cn
    return epoch_db_dt_cn(ts)


@dataclass
class JobInfo:
    job_id: str
    site_code: str
    target_code: str
    source: str                             # "manual" | "schedule" | "retry"
    status: JobStatus = "queued"
    enqueued_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    error_msg: Optional[str] = None
    result: Optional[dict] = None
    # When true, pipeline runs with stop_on_duplicate=False — used for
    # backfill / 全量抓取 to ignore the historical-boundary early-stop and
    # walk every page even if all entries are dedup-skipped. Default False
    # preserves the increment-friendly behavior for cron and casual ▶.
    force: bool = False
    # Initial checkpoint used when a new manual job is explicitly queued to
    # continue from a previous run. Once persisted, crawl_target reads the
    # durable crawl_job.last_completed_page value before starting.
    initial_last_completed_page: int = 0
    # Live progress page. This is informational only; resume still uses
    # last_completed_page from the durable row.
    current_page: int = 0
    # Set to True when operator requests termination; pipeline code can poll
    # this from queue.job_for(job_id).stop_requested if it grows cancellation
    # hooks. For v1 the flag just marks the job as "cancelled" once the
    # current sync run returns.
    stop_requested: bool = False

    def to_dict(self) -> dict:
        now = time.time()
        running_for = None
        if self.started_at and not self.finished_at:
            running_for = now - self.started_at
        elif self.started_at and self.finished_at:
            running_for = self.finished_at - self.started_at
        return {
            "job_id": self.job_id,
            "site_code": self.site_code,
            "target_code": self.target_code,
            "source": self.source,
            "status": self.status,
            "stop_requested": self.stop_requested,
            "enqueued_at": _epoch_iso(self.enqueued_at),
            "started_at": _epoch_iso(self.started_at),
            "finished_at": _epoch_iso(self.finished_at),
            "running_for_sec": round(running_for, 1) if running_for else None,
            "queued_for_sec": round(
                (self.started_at or now) - self.enqueued_at, 1
            ) if self.enqueued_at else None,
            "error_msg": self.error_msg,
            "result": self.result,
            "initial_last_completed_page": self.initial_last_completed_page,
            "current_page": self.current_page,
        }


# Cap the number of "done/failed/cancelled" jobs we keep around for the UI
# history pane — prevents unbounded memory growth on a long-running api.
HISTORY_KEEP = 200


class TaskQueue:
    def __init__(self) -> None:
        self._jobs: dict[str, JobInfo] = {}
        # Queues keyed by HOST (netloc), not site_code — so all sites
        # sharing a base_url (gd_gkmlpt + gd_wjk on www.gd.gov.cn) serialize
        # through one FIFO. Same-host parallel hits trip WAF rate limiters.
        self._host_queues: dict[str, asyncio.Queue[str]] = {}
        self._workers: dict[str, asyncio.Task] = {}
        self._history_order: list[str] = []  # job_ids in completion order (FIFO trim)
        self._lock = asyncio.Lock()

    # ---------- DB persistence helpers ----------

    @staticmethod
    def _db_upsert_job(j: "JobInfo", *, host: str) -> None:
        """Mirror an in-memory JobInfo to the crawl_job DB row.

        Called on every state transition (queued → running → done/failed/
        cancelled). Best-effort: a DB error here must NOT take down the
        in-memory queue, so we swallow + log. The row may temporarily be
        out of sync until the next transition fixes it."""
        try:
            from govcrawler.db import get_sessionmaker
            from govcrawler.models import CrawlJob
            S = get_sessionmaker()
            with S() as s:
                row = s.get(CrawlJob, j.job_id)
                if row is None:
                    row = CrawlJob(
                        job_id=j.job_id,
                        host=host,
                        site_code=j.site_code,
                        target_code=j.target_code,
                        source=j.source,
                        status=j.status,
                        force=j.force,
                        stop_requested=j.stop_requested,
                        last_completed_page=max(0, int(j.initial_last_completed_page or 0)),
                        current_page=max(0, int(j.current_page or 0)),
                        attempt_count=0,
                        enqueued_at=_epoch_db_dt(j.enqueued_at),
                    )
                    s.add(row)
                else:
                    row.status = j.status
                    row.stop_requested = j.stop_requested
                    row.current_page = max(0, int(j.current_page or 0))
                row.started_at = _epoch_db_dt(j.started_at)
                row.finished_at = _epoch_db_dt(j.finished_at)
                row.error_msg = j.error_msg
                row.result_json = j.result
                s.commit()
        except Exception:
            log.exception("crawl_job upsert failed job=%s", j.job_id)

    async def restore_from_db(self) -> dict[str, int]:
        """Boot-time recovery: any 'queued' or 'running' row whose process
        died is re-enqueued into the in-memory FIFO so a fresh worker
        drains it. The crawl is idempotent (dedup early-stop handles the
        already-fetched URLs in zero time), so resuming is safer than
        marking failed — operator's manual ▶ shouldn't get burned by a
        deploy that happens to overlap.

        Only jobs that exceeded MAX_RESTART_RECOVERY (3) get marked failed
        permanently — that bounds the loop in case some job keeps killing
        the api process repeatedly.

        Idempotent — safe to call once at api startup. Returns counts so
        the startup logger can show what was restored."""
        from govcrawler.db import get_sessionmaker
        from govcrawler.models import CrawlJob
        from govcrawler.timeutil import now_cn_naive

        MAX_RESTART_RECOVERY = 3

        recovered = 0
        requeued = 0
        permanently_failed = 0

        S = get_sessionmaker()
        with S() as s:
            # 1. running rows are orphans — their process is gone. Bump
            # attempt_count and recover; if that crosses MAX, give up.
            for row in s.query(CrawlJob).filter_by(status="running").all():
                if (row.attempt_count or 0) + 1 >= MAX_RESTART_RECOVERY:
                    row.status = "failed"
                    row.finished_at = now_cn_naive()
                    row.error_msg = (row.error_msg or "") + (
                        " | "
                        if row.error_msg else ""
                    ) + (
                        f"abandoned after {MAX_RESTART_RECOVERY} restart "
                        "recoveries (job kept crashing the api?)"
                    )
                    permanently_failed += 1
                    continue
                row.status = "queued"
                row.attempt_count = (row.attempt_count or 0) + 1
                # Clear started_at so the next run starts a fresh stopwatch.
                row.started_at = None
                # Annotate (not overwrite) error_msg so audit trail
                # preserves prior failures.
                annotation = f"restart_during_run (recovery attempt {row.attempt_count}/{MAX_RESTART_RECOVERY})"
                row.error_msg = (
                    f"{row.error_msg} | {annotation}"
                    if row.error_msg else annotation
                )
                recovered += 1
            s.commit()

            # 2. Re-enqueue all queued jobs (including those we just
            # converted from running). Original FIFO order.
            queued_rows = (
                s.query(CrawlJob)
                .filter_by(status="queued")
                .order_by(CrawlJob.enqueued_at.asc())
                .all()
            )
            for row in queued_rows:
                ji = JobInfo(
                    job_id=row.job_id,
                    site_code=row.site_code,
                    target_code=row.target_code,
                    source=row.source,
                    force=bool(row.force),
                    stop_requested=bool(row.stop_requested),
                    current_page=max(0, int(row.current_page or 0)),
                    enqueued_at=row.enqueued_at.timestamp() if row.enqueued_at else time.time(),
                )
                self._jobs[row.job_id] = ji
                host = row.host or _host_for_site(row.site_code)
                async with self._lock:
                    q = self._host_queues.get(host)
                    if q is None:
                        q = asyncio.Queue()
                        self._host_queues[host] = q
                        self._workers[host] = asyncio.create_task(
                            self._worker_loop(host, q),
                            name=f"taskq-{host}",
                        )
                await q.put(row.job_id)
                requeued += 1

        log.info(
            "task_queue restore: running→requeued=%d, queued_requeued=%d, "
            "permanently_failed=%d",
            recovered, requeued - recovered, permanently_failed,
        )
        return {
            "recovered": recovered,
            "requeued": requeued,
            "permanently_failed": permanently_failed,
        }

    # ---------- public API ----------

    async def submit(
        self, *, site_code: str, target_code: str, source: str = "manual",
        force: bool = False, resume_from_page: int = 0,
    ) -> str:
        """Enqueue a crawl job. Returns job_id. Spawns the per-host worker if
        this is the first job for that host. Persists to crawl_job table
        so a container restart doesn't lose the job."""
        host = _host_for_site(site_code)
        async with self._lock:
            job = JobInfo(
                job_id=uuid.uuid4().hex[:12],
                site_code=site_code,
                target_code=target_code,
                source=source,
                force=force,
                initial_last_completed_page=max(0, int(resume_from_page or 0)),
            )
            self._jobs[job.job_id] = job
            q = self._host_queues.get(host)
            if q is None:
                q = asyncio.Queue()
                self._host_queues[host] = q
                self._workers[host] = asyncio.create_task(
                    self._worker_loop(host, q),
                    name=f"taskq-{host}",
                )
            await q.put(job.job_id)
        # Persist AFTER the in-memory enqueue. If DB write fails the
        # job still runs (we'd rather lose durability than lose the run);
        # the next status transition will retry the upsert.
        self._db_upsert_job(job, host=host)
        log.info(
            "queue submit job=%s host=%s site=%s target=%s source=%s resume_from_page=%s",
            job.job_id, host, site_code, target_code, source, job.initial_last_completed_page,
        )
        return job.job_id

    def job(self, job_id: str) -> JobInfo | None:
        return self._jobs.get(job_id)

    def list_jobs(
        self, *, site: str | None = None, status: str | None = None,
        include_history: bool = True, limit: int = 200,
    ) -> list[JobInfo]:
        out: list[JobInfo] = []
        for j in self._jobs.values():
            if site and j.site_code != site:
                continue
            if status and j.status != status:
                continue
            if not include_history and j.status in ("done", "failed", "cancelled"):
                continue
            out.append(j)
        # Sort: running first, then queued (by enqueued_at asc), then history (desc finished_at)
        def sort_key(j: JobInfo):
            rank = {"running": 0, "queued": 1, "cancelled": 2, "failed": 2, "done": 3}.get(j.status, 9)
            ts = j.finished_at or j.started_at or j.enqueued_at
            # queued → ascending (older first); others → descending (newest first)
            if j.status == "queued":
                return (rank, ts)
            return (rank, -ts)
        out.sort(key=sort_key)
        return out[:limit]

    def queue_summary(self) -> list[dict]:
        """One row per active host. With host-based queueing, multiple
        site_codes (e.g. gd_gkmlpt, gd_wjk) sharing www.gd.gov.cn appear in
        the same row's site_codes list."""
        out = []
        for host, q in self._host_queues.items():
            host_jobs = [
                j for j in self._jobs.values()
                if _host_for_site(j.site_code) == host
            ]
            running = next((j for j in host_jobs if j.status == "running"), None)
            queued = [j for j in host_jobs if j.status == "queued"]
            site_codes = sorted({j.site_code for j in host_jobs})
            out.append({
                "host": host,
                "site_codes": site_codes,
                # Back-compat: UI may still reference site_code; expose the
                # running job's site_code (or first known on this host).
                "site_code": running.site_code if running else (site_codes[0] if site_codes else host),
                "running_job_id": running.job_id if running else None,
                "running_target": running.target_code if running else None,
                "running_site": running.site_code if running else None,
                "running_for_sec": round(time.time() - running.started_at, 1)
                    if running and running.started_at else None,
                "queued_count": len(queued),
                "next_queued_target": queued[0].target_code if queued else None,
                "next_queued_site": queued[0].site_code if queued else None,
            })
        return out

    async def cancel(self, job_id: str) -> dict:
        """Cancel a queued job (remove from queue) or flag a running one.

        For queued jobs: flip status → cancelled; worker will skip when it pops.
        For running jobs: set stop_requested = True. The current crawl run
        will continue (sync function), but status will transition to
        cancelled on completion instead of done.
        """
        j = self._jobs.get(job_id)
        if j is None:
            return {"ok": False, "reason": "not_found"}
        if j.status == "queued":
            j.status = "cancelled"
            j.finished_at = time.time()
            self._db_upsert_job(j, host=_host_for_site(j.site_code))
            log.info("queue cancel queued job=%s site=%s", job_id, j.site_code)
            return {"ok": True, "cancelled": "queued"}
        if j.status == "running":
            j.stop_requested = True
            self._db_upsert_job(j, host=_host_for_site(j.site_code))
            log.info("queue cancel running (flag) job=%s site=%s", job_id, j.site_code)
            return {"ok": True, "cancelled": "running_flagged"}
        return {"ok": False, "reason": f"cannot cancel status={j.status}"}

    # ---------- internals ----------

    async def _worker_loop(self, host: str, q: asyncio.Queue[str]) -> None:
        log.info("taskq worker started host=%s", host)
        while True:
            job_id = await q.get()
            j = self._jobs.get(job_id)
            if j is None or j.status == "cancelled":
                q.task_done()
                continue
            j.status = "running"
            j.started_at = time.time()
            self._db_upsert_job(j, host=host)
            log.info("taskq START job=%s host=%s site=%s target=%s (queued for %.1fs)",
                     job_id, host, j.site_code, j.target_code,
                     j.started_at - j.enqueued_at)
            try:
                # Run the (sync) crawl in a thread so we don't block the
                # event loop. Pass a thread-safe stop_check closure that
                # peeks at j.stop_requested — pipeline polls it between
                # list pages and between articles so /cancel actually
                # halts the in-flight crawl mid-flight, not just labels
                # the row 'cancelled' after it finishes naturally.
                from govcrawler.pipeline import crawl_target

                # Resume from saved page checkpoint if the target opted in.
                # crawl_target reads track_checkpoint flag itself; we just
                # supply the last completed page from the durable row.
                resume_page = 0
                try:
                    from govcrawler.db import get_sessionmaker as _S
                    from govcrawler.models import CrawlJob as _CJ
                    with _S()() as _s:
                        _row = _s.get(_CJ, j.job_id)
                        if _row is not None:
                            resume_page = int(_row.last_completed_page or 0)
                except Exception:
                    pass

                result = await asyncio.to_thread(
                    crawl_target,
                    j.target_code,
                    stop_on_duplicate=not j.force,
                    stop_check=lambda: j.stop_requested,
                    job_id=j.job_id,
                    resume_from_page=resume_page,
                    progress_callback=lambda page: setattr(j, "current_page", max(0, int(page or 0))),
                )
                j.result = result if isinstance(result, dict) else {"ok": True}
                if j.stop_requested or (
                    isinstance(result, dict) and result.get("status") == "cancelled"
                ):
                    j.status = "cancelled"
                else:
                    j.status = "done"
            except Exception as e:
                log.exception("taskq FAIL job=%s target=%s", job_id, j.target_code)
                j.error_msg = f"{type(e).__name__}: {e}"
                j.result = {
                    "error_type": type(e).__name__,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                j.status = "failed"
            finally:
                j.finished_at = time.time()
                self._db_upsert_job(j, host=host)
                self._submit_rag_export_after_crawl(j)
                self._record_history(job_id)
                q.task_done()
                log.info("taskq END job=%s host=%s site=%s target=%s status=%s elapsed=%.1fs",
                         job_id, host, j.site_code, j.target_code, j.status,
                         (j.finished_at - (j.started_at or j.finished_at)))

    def _submit_rag_export_after_crawl(self, job: JobInfo) -> None:
        if job.status != "done":
            return
        try:
            from govcrawler.runtime_config import is_rag_export_after_crawl_enabled
            if not is_rag_export_after_crawl_enabled():
                return
        except Exception:
            log.exception("rag export config read failed job=%s target=%s", job.job_id, job.target_code)
            return
        asyncio.create_task(
            self._run_rag_export_for_job(job.job_id, job.target_code),
            name=f"rag-export-{job.job_id}",
        )

    async def _run_rag_export_for_job(self, job_id: str, target_code: str) -> None:
        job = self._jobs.get(job_id)
        if job is not None:
            job.result = job.result or {}
            job.result["rag_export"] = {"status": "running", "target_code": target_code}
            self._db_upsert_job(job, host=_host_for_site(job.site_code))
        try:
            result = await asyncio.to_thread(_export_target_to_rag, target_code)
            job = self._jobs.get(job_id)
            if job is not None:
                job.result = job.result or {}
                job.result["rag_export"] = {
                    "status": "completed" if result.get("failed", 0) == 0 else "partial_failed",
                    "target_code": target_code,
                    **result,
                }
                self._db_upsert_job(job, host=_host_for_site(job.site_code))
            log.info(
                "rag export after crawl done job=%s target=%s total=%s exported=%s failed=%s",
                job_id, target_code, result.get("total"), result.get("exported"), result.get("failed"),
            )
        except Exception as exc:
            log.exception("rag export after crawl failed job=%s target=%s", job_id, target_code)
            job = self._jobs.get(job_id)
            if job is not None:
                job.result = job.result or {}
                job.result["rag_export"] = {
                    "status": "failed",
                    "target_code": target_code,
                    "error": str(exc),
                }
                self._db_upsert_job(job, host=_host_for_site(job.site_code))

    def _record_history(self, job_id: str) -> None:
        self._history_order.append(job_id)
        if len(self._history_order) > HISTORY_KEEP:
            evict = self._history_order.pop(0)
            j = self._jobs.get(evict)
            if j and j.status in ("done", "failed", "cancelled"):
                self._jobs.pop(evict, None)


# Singleton — api registers startup hook to initialize.
_queue: TaskQueue | None = None


def get_queue() -> TaskQueue:
    global _queue
    if _queue is None:
        _queue = TaskQueue()
    return _queue


def _export_target_to_rag(target_code: str) -> dict:
    from govcrawler.rag.exporter import RagExporter

    exporter = RagExporter()
    try:
        result = asdict(exporter.export_pending(target_code=target_code, source="auto"))
        result["target_code"] = target_code
        return result
    finally:
        exporter.close()
