"""Per-site task queue for crawl jobs — prevents concurrent hits on a site.

Design:
  • A submit() call returns immediately with a job_id. The job is appended
    to the per-site FIFO; a single worker task per site drains it.
  • Concurrency across sites is unlimited; within one site it's strictly 1.
  • State is in-process memory. Lost on container restart — that's fine for
    v1 (jobs are idempotent; scheduler will re-trigger them next tick).
  • cancel() removes a queued job or sets the "stop requested" flag on a
    running one. Hard-kill of the sync crawl_target is intentionally not
    supported (would need a subprocess).

Call sites:
  • admin/targets.run_target  → submit()
  • scheduler cron fn         → submit()  (via HTTP to api; next step)
  • admin/jobs.* endpoints     → list() + cancel()
"""
from __future__ import annotations

import asyncio
import logging
import time
import traceback
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlparse

log = logging.getLogger(__name__)


# Cache: site_code → host (e.g. "gd_gkmlpt" → "www.gd.gov.cn"). Multiple
# site_codes can share the same host (e.g. gd_gkmlpt + gd_wjk both live on
# www.gd.gov.cn) and we want them to share a single FIFO so we never hit
# the host concurrently — same-host parallelism trips WAF rate limiters.
# Cache is invalidated lazily: a site row's base_url rarely changes; if it
# does, an api restart re-primes the cache.
_HOST_CACHE: dict[str, str] = {}


def _host_for_site(site_code: str) -> str:
    """Resolve site_code → host (lowercase netloc of base_url). Falls back
    to site_code itself if the lookup fails (unknown site, DB hiccup), so
    queue serialization degrades gracefully to per-site rather than
    silently colliding everything into one queue."""
    cached = _HOST_CACHE.get(site_code)
    if cached is not None:
        return cached
    try:
        from govcrawler.db import get_sessionmaker
        from govcrawler.models import CrawlSite
        S = get_sessionmaker()
        with S() as s:
            row = s.query(CrawlSite).filter_by(site_code=site_code).first()
            if row and row.base_url:
                host = urlparse(row.base_url).netloc.lower() or site_code
                _HOST_CACHE[site_code] = host
                return host
    except Exception:
        pass
    _HOST_CACHE[site_code] = site_code
    return site_code

JobStatus = Literal[
    "queued",       # sitting in the site's FIFO
    "running",      # worker picked it up
    "done",         # completed successfully
    "failed",       # crawl_target raised
    "cancelled",    # dequeued before running, or terminate-requested mid-run
]


@dataclass
class JobInfo:
    job_id: str
    site_code: str
    target_code: str
    source: str                             # "manual" | "schedule" | "retry"
    status: JobStatus = "queued"
    enqueued_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    error_msg: Optional[str] = None
    result: Optional[dict] = None
    # When true, pipeline runs with stop_on_duplicate=False — used for
    # backfill / 全量抓取 to ignore the historical-boundary early-stop and
    # walk every page even if all entries are dedup-skipped. Default False
    # preserves the increment-friendly behavior for cron and casual ▶.
    force: bool = False
    # Initial checkpoint used when a new manual job is explicitly queued to
    # continue from a previous run. Once persisted, crawl_target reads the
    # durable crawl_job.last_completed_page value before starting.
    initial_last_completed_page: int = 0
    # Live progress page. This is informational only; resume still uses
    # last_completed_page from the durable row.
    current_page: int = 0
    # Set to True when operator requests termination; pipeline code can poll
    # this from queue.job_for(job_id).stop_requested if it grows cancellation
    # hooks. For v1 the flag just marks the job as "cancelled" once the
    # current sync run returns.
    stop_requested: bool = False

    def to_dict(self) -> dict:
        now = time.time()
        running_for = None
        if self.started_at and not self.finished_at:
            running_for = now - self.started_at
        elif self.started_at and self.finished_at:
            running_for = self.finished_at - self.started_at
        return {
            "job_id": self.job_id,
            "site_code": self.site_code,
            "target_code": self.target_code,
            "source": self.source,
            "status": self.status,
            "stop_requested": self.stop_requested,
            "enqueued_at": datetime.utcfromtimestamp(self.enqueued_at).isoformat() if self.enqueued_at else None,
            "started_at": datetime.utcfromtimestamp(self.started_at).isoformat() if self.started_at else None,
            "finished_at": datetime.utcfromtimestamp(self.finished_at).isoformat() if self.finished_at else None,
            "running_for_sec": round(running_for, 1) if running_for else None,
            "queued_for_sec": round(
                (self.started_at or now) - self.enqueued_at, 1
            ) if self.enqueued_at else None,
            "error_msg": self.error_msg,
            "result": self.result,
            "initial_last_completed_page": self.initial_last_completed_page,
            "current_page": self.current_page,
        }


# Cap the number of "done/failed/cancelled" jobs we keep around for the UI
# history pane — prevents unbounded memory growth on a long-running api.
HISTORY_KEEP = 200


class TaskQueue:
    def __init__(self) -> None:
        self._jobs: dict[str, JobInfo] = {}
        # Queues keyed by HOST (netloc), not site_code — so all sites
        # sharing a base_url (gd_gkmlpt + gd_wjk on www.gd.gov.cn) serialize
        # through one FIFO. Same-host parallel hits trip WAF rate limiters.
        self._host_queues: dict[str, asyncio.Queue[str]] = {}
        self._workers: dict[str, asyncio.Task] = {}
        self._history_order: list[str] = []  # job_ids in completion order (FIFO trim)
        self._lock = asyncio.Lock()

    # ---------- DB persistence helpers ----------

    @staticmethod
    def _db_upsert_job(j: "JobInfo", *, host: str) -> None:
        """Mirror an in-memory JobInfo to the crawl_job DB row.

        Called on every state transition (queued → running → done/failed/
        cancelled). Best-effort: a DB error here must NOT take down the
        in-memory queue, so we swallow + log. The row may temporarily be
        out of sync until the next transition fixes it."""
        try:
            from datetime import datetime
            from govcrawler.db import get_sessionmaker
            from govcrawler.models import CrawlJob
            S = get_sessionmaker()
            with S() as s:
                row = s.get(CrawlJob, j.job_id)
                if row is None:
                    row = CrawlJob(
                        job_id=j.job_id,
                        host=host,
                        site_code=j.site_code,
                        target_code=j.target_code,
                        source=j.source,
                        status=j.status,
                        force=j.force,
                        stop_requested=j.stop_requested,
                        last_completed_page=max(0, int(j.initial_last_completed_page or 0)),
                        current_page=max(0, int(j.current_page or 0)),
                        attempt_count=0,
                        enqueued_at=datetime.utcfromtimestamp(j.enqueued_at),
                    )
                    s.add(row)
                else:
                    row.status = j.status
                    row.stop_requested = j.stop_requested
                    row.current_page = max(0, int(j.current_page or 0))
                row.started_at = (
                    datetime.utcfromtimestamp(j.started_at) if j.started_at else None
                )
                row.finished_at = (
                    datetime.utcfromtimestamp(j.finished_at) if j.finished_at else None
                )
                row.error_msg = j.error_msg
                row.result_json = j.result
                s.commit()
        except Exception:
            log.exception("crawl_job upsert failed job=%s", j.job_id)

    async def restore_from_db(self) -> dict[str, int]:
        """Boot-time recovery: any 'queued' or 'running' row whose process
        died is re-enqueued into the in-memory FIFO so a fresh worker
        drains it. The crawl is idempotent (dedup early-stop handles the
        already-fetched URLs in zero time), so resuming is safer than
        marking failed — operator's manual ▶ shouldn't get burned by a
        deploy that happens to overlap.

        Only jobs that exceeded MAX_RESTART_RECOVERY (3) get marked failed
        permanently — that bounds the loop in case some job keeps killing
        the api process repeatedly.

        Idempotent — safe to call once at api startup. Returns counts so
        the startup logger can show what was restored."""
        from datetime import datetime
        from govcrawler.db import get_sessionmaker
        from govcrawler.models import CrawlJob

        MAX_RESTART_RECOVERY = 3

        recovered = 0
        requeued = 0
        permanently_failed = 0

        S = get_sessionmaker()
        with S() as s:
            # 1. running rows are orphans — their process is gone. Bump
            # attempt_count and recover; if that crosses MAX, give up.
            for row in s.query(CrawlJob).filter_by(status="running").all():
                if (row.attempt_count or 0) + 1 >= MAX_RESTART_RECOVERY:
                    row.status = "failed"
                    row.finished_at = datetime.utcnow()
                    row.error_msg = (row.error_msg or "") + (
                        " | "
                        if row.error_msg else ""
                    ) + (
                        f"abandoned after {MAX_RESTART_RECOVERY} restart "
                        "recoveries (job kept crashing the api?)"
                    )
                    permanently_failed += 1
                    continue
                row.status = "queued"
                row.attempt_count = (row.attempt_count or 0) + 1
                # Clear started_at so the next run starts a fresh stopwatch.
                row.started_at = None
                # Annotate (not overwrite) error_msg so audit trail
                # preserves prior failures.
                annotation = f"restart_during_run (recovery attempt {row.attempt_count}/{MAX_RESTART_RECOVERY})"
                row.error_msg = (
                    f"{row.error_msg} | {annotation}"
                    if row.error_msg else annotation
                )
                recovered += 1
            s.commit()

            # 2. Re-enqueue all queued jobs (including those we just
            # converted from running). Original FIFO order.
            queued_rows = (
                s.query(CrawlJob)
                .filter_by(status="queued")
                .order_by(CrawlJob.enqueued_at.asc())
                .all()
            )
            for row in queued_rows:
                ji = JobInfo(
                    job_id=row.job_id,
                    site_code=row.site_code,
                    target_code=row.target_code,
                    source=row.source,
                    force=bool(row.force),
                    stop_requested=bool(row.stop_requested),
                    current_page=max(0, int(row.current_page or 0)),
                    enqueued_at=row.enqueued_at.timestamp() if row.enqueued_at else time.time(),
                )
                self._jobs[row.job_id] = ji
                host = row.host or _host_for_site(row.site_code)
                async with self._lock:
                    q = self._host_queues.get(host)
                    if q is None:
                        q = asyncio.Queue()
                        self._host_queues[host] = q
                        self._workers[host] = asyncio.create_task(
                            self._worker_loop(host, q),
                            name=f"taskq-{host}",
                        )
                await q.put(row.job_id)
                requeued += 1

        log.info(
            "task_queue restore: running→requeued=%d, queued_requeued=%d, "
            "permanently_failed=%d",
            recovered, requeued - recovered, permanently_failed,
        )
        return {
            "recovered": recovered,
            "requeued": requeued,
            "permanently_failed": permanently_failed,
        }

    # ---------- public API ----------

    async def submit(
        self, *, site_code: str, target_code: str, source: str = "manual",
        force: bool = False, resume_from_page: int = 0,
    ) -> str:
        """Enqueue a crawl job. Returns job_id. Spawns the per-host worker if
        this is the first job for that host. Persists to crawl_job table
        so a container restart doesn't lose the job."""
        host = _host_for_site(site_code)
        async with self._lock:
            job = JobInfo(
                job_id=uuid.uuid4().hex[:12],
                site_code=site_code,
                target_code=target_code,
                source=source,
                force=force,
                initial_last_completed_page=max(0, int(resume_from_page or 0)),
            )
            self._jobs[job.job_id] = job
            q = self._host_queues.get(host)
            if q is None:
                q = asyncio.Queue()
                self._host_queues[host] = q
                self._workers[host] = asyncio.create_task(
                    self._worker_loop(host, q),
                    name=f"taskq-{host}",
                )
            await q.put(job.job_id)
        # Persist AFTER the in-memory enqueue. If DB write fails the
        # job still runs (we'd rather lose durability than lose the run);
        # the next status transition will retry the upsert.
        self._db_upsert_job(job, host=host)
        log.info(
            "queue submit job=%s host=%s site=%s target=%s source=%s resume_from_page=%s",
            job.job_id, host, site_code, target_code, source, job.initial_last_completed_page,
        )
        return job.job_id

    def job(self, job_id: str) -> JobInfo | None:
        return self._jobs.get(job_id)

    def list_jobs(
        self, *, site: str | None = None, status: str | None = None,
        include_history: bool = True, limit: int = 200,
    ) -> list[JobInfo]:
        out: list[JobInfo] = []
        for j in self._jobs.values():
            if site and j.site_code != site:
                continue
            if status and j.status != status:
                continue
            if not include_history and j.status in ("done", "failed", "cancelled"):
                continue
            out.append(j)
        # Sort: running first, then queued (by enqueued_at asc), then history (desc finished_at)
        def sort_key(j: JobInfo):
            rank = {"running": 0, "queued": 1, "cancelled": 2, "failed": 2, "done": 3}.get(j.status, 9)
            ts = j.finished_at or j.started_at or j.enqueued_at
            # queued → ascending (older first); others → descending (newest first)
            if j.status == "queued":
                return (rank, ts)
            return (rank, -ts)
        out.sort(key=sort_key)
        return out[:limit]

    def queue_summary(self) -> list[dict]:
        """One row per active host. With host-based queueing, multiple
        site_codes (e.g. gd_gkmlpt, gd_wjk) sharing www.gd.gov.cn appear in
        the same row's site_codes list."""
        out = []
        for host, q in self._host_queues.items():
            host_jobs = [
                j for j in self._jobs.values()
                if _host_for_site(j.site_code) == host
            ]
            running = next((j for j in host_jobs if j.status == "running"), None)
            queued = [j for j in host_jobs if j.status == "queued"]
            site_codes = sorted({j.site_code for j in host_jobs})
            out.append({
                "host": host,
                "site_codes": site_codes,
                # Back-compat: UI may still reference site_code; expose the
                # running job's site_code (or first known on this host).
                "site_code": running.site_code if running else (site_codes[0] if site_codes else host),
                "running_job_id": running.job_id if running else None,
                "running_target": running.target_code if running else None,
                "running_site": running.site_code if running else None,
                "running_for_sec": round(time.time() - running.started_at, 1)
                    if running and running.started_at else None,
                "queued_count": len(queued),
                "next_queued_target": queued[0].target_code if queued else None,
                "next_queued_site": queued[0].site_code if queued else None,
            })
        return out

    async def cancel(self, job_id: str) -> dict:
        """Cancel a queued job (remove from queue) or flag a running one.

        For queued jobs: flip status → cancelled; worker will skip when it pops.
        For running jobs: set stop_requested = True. The current crawl run
        will continue (sync function), but status will transition to
        cancelled on completion instead of done.
        """
        j = self._jobs.get(job_id)
        if j is None:
            return {"ok": False, "reason": "not_found"}
        if j.status == "queued":
            j.status = "cancelled"
            j.finished_at = time.time()
            self._db_upsert_job(j, host=_host_for_site(j.site_code))
            log.info("queue cancel queued job=%s site=%s", job_id, j.site_code)
            return {"ok": True, "cancelled": "queued"}
        if j.status == "running":
            j.stop_requested = True
            self._db_upsert_job(j, host=_host_for_site(j.site_code))
            log.info("queue cancel running (flag) job=%s site=%s", job_id, j.site_code)
            return {"ok": True, "cancelled": "running_flagged"}
        return {"ok": False, "reason": f"cannot cancel status={j.status}"}

    # ---------- internals ----------

    async def _worker_loop(self, host: str, q: asyncio.Queue[str]) -> None:
        log.info("taskq worker started host=%s", host)
        while True:
            job_id = await q.get()
            j = self._jobs.get(job_id)
            if j is None or j.status == "cancelled":
                q.task_done()
                continue
            j.status = "running"
            j.started_at = time.time()
            self._db_upsert_job(j, host=host)
            log.info("taskq START job=%s host=%s site=%s target=%s (queued for %.1fs)",
                     job_id, host, j.site_code, j.target_code,
                     j.started_at - j.enqueued_at)
            try:
                # Run the (sync) crawl in a thread so we don't block the
                # event loop. Pass a thread-safe stop_check closure that
                # peeks at j.stop_requested — pipeline polls it between
                # list pages and between articles so /cancel actually
                # halts the in-flight crawl mid-flight, not just labels
                # the row 'cancelled' after it finishes naturally.
                from govcrawler.pipeline import crawl_target

                # Resume from saved page checkpoint if the target opted in.
                # crawl_target reads track_checkpoint flag itself; we just
                # supply the last completed page from the durable row.
                resume_page = 0
                try:
                    from govcrawler.db import get_sessionmaker as _S
                    from govcrawler.models import CrawlJob as _CJ
                    with _S()() as _s:
                        _row = _s.get(_CJ, j.job_id)
                        if _row is not None:
                            resume_page = int(_row.last_completed_page or 0)
                except Exception:
                    pass

                result = await asyncio.to_thread(
                    crawl_target,
                    j.target_code,
                    stop_on_duplicate=not j.force,
                    stop_check=lambda: j.stop_requested,
                    job_id=j.job_id,
                    resume_from_page=resume_page,
                    progress_callback=lambda page: setattr(j, "current_page", max(0, int(page or 0))),
                )
                j.result = result if isinstance(result, dict) else {"ok": True}
                if j.stop_requested or (
                    isinstance(result, dict) and result.get("status") == "cancelled"
                ):
                    j.status = "cancelled"
                else:
                    j.status = "done"
            except Exception as e:
                log.exception("taskq FAIL job=%s target=%s", job_id, j.target_code)
                j.error_msg = f"{type(e).__name__}: {e}"
                j.result = {
                    "error_type": type(e).__name__,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                }
                j.status = "failed"
            finally:
                j.finished_at = time.time()
                self._db_upsert_job(j, host=host)
                self._submit_rag_export_after_crawl(j)
                self._record_history(job_id)
                q.task_done()
                log.info("taskq END job=%s host=%s site=%s target=%s status=%s elapsed=%.1fs",
                         job_id, host, j.site_code, j.target_code, j.status,
                         (j.finished_at - (j.started_at or j.finished_at)))

    def _submit_rag_export_after_crawl(self, job: JobInfo) -> None:
        if job.status != "done":
            return
        try:
            from govcrawler.settings import get_settings
            if not get_settings().rag_export_after_crawl_enabled:
                return
        except Exception:
            log.exception("rag export config read failed job=%s target=%s", job.job_id, job.target_code)
            return
        asyncio.create_task(
            self._run_rag_export_for_job(job.job_id, job.target_code),
            name=f"rag-export-{job.job_id}",
        )

    async def _run_rag_export_for_job(self, job_id: str, target_code: str) -> None:
        job = self._jobs.get(job_id)
        if job is not None:
            job.result = job.result or {}
            job.result["rag_export"] = {"status": "running", "target_code": target_code}
            self._db_upsert_job(job, host=_host_for_site(job.site_code))
        try:
            result = await asyncio.to_thread(_export_target_to_rag, target_code)
            job = self._jobs.get(job_id)
            if job is not None:
                job.result = job.result or {}
                job.result["rag_export"] = {
                    "status": "completed" if result.get("failed", 0) == 0 else "partial_failed",
                    "target_code": target_code,
                    **result,
                }
                self._db_upsert_job(job, host=_host_for_site(job.site_code))
            log.info(
                "rag export after crawl done job=%s target=%s total=%s exported=%s failed=%s",
                job_id, target_code, result.get("total"), result.get("exported"), result.get("failed"),
            )
        except Exception as exc:
            log.exception("rag export after crawl failed job=%s target=%s", job_id, target_code)
            job = self._jobs.get(job_id)
            if job is not None:
                job.result = job.result or {}
                job.result["rag_export"] = {
                    "status": "failed",
                    "target_code": target_code,
                    "error": str(exc),
                }
                self._db_upsert_job(job, host=_host_for_site(job.site_code))

    def _record_history(self, job_id: str) -> None:
        self._history_order.append(job_id)
        if len(self._history_order) > HISTORY_KEEP:
            evict = self._history_order.pop(0)
            j = self._jobs.get(evict)
            if j and j.status in ("done", "failed", "cancelled"):
                self._jobs.pop(evict, None)


# Singleton — api registers startup hook to initialize.
_queue: TaskQueue | None = None


def get_queue() -> TaskQueue:
    global _queue
    if _queue is None:
        _queue = TaskQueue()
    return _queue


def _export_target_to_rag(target_code: str) -> dict:
    from govcrawler.rag.exporter import RagExporter

    exporter = RagExporter()
    try:
        result = asdict(exporter.export_pending(target_code=target_code))
        result["target_code"] = target_code
        return result
    finally:
        exporter.close()
