from __future__ import annotations

import json
import mimetypes
import re
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path, PurePosixPath
from typing import Any

import httpx
from sqlalchemy import and_, or_, select, update
from sqlalchemy.orm import Session, selectinload

from govcrawler.db import get_sessionmaker
from govcrawler.models import Article, ArticleRagPushLog, Attachment, CrawlTarget
from govcrawler.settings import Settings, get_settings
from govcrawler.storage.paths import to_os_path

KNOWLEDGE_CATEGORY = "信息采集"


@dataclass
class ExportedDocument:
    doc_id: str
    kind: str
    title: str
    filename: str
    status: str
    task_id: str | None = None
    error: str | None = None


@dataclass
class ArticleExportResult:
    article_id: int
    title: str
    status: str
    documents: list[ExportedDocument] = field(default_factory=list)
    error: str | None = None


@dataclass
class BatchExportResult:
    total: int
    exported: int
    failed: int
    dry_run: bool
    items: list[ArticleExportResult]


def article_doc_id(article_id: int) -> str:
    """Stable zm-rag doc_id for the article body.

    Keep this Windows-filename-safe because zm-rag currently uses doc_id when
    creating a temporary upload path.
    """
    return f"govcrawler_article_{article_id}"


def attachment_doc_id(article_id: int, attachment_id: int) -> str:
    return f"govcrawler_article_{article_id}_attachment_{attachment_id}"


def split_subject_words(raw: str | None) -> list[str]:
    if not raw:
        return []
    parts = re.split(r"[、,，;；\s]+", raw)
    return list(dict.fromkeys(p.strip() for p in parts if p.strip()))


def _date_str(value: Any) -> str:
    if value is None:
        return ""
    if hasattr(value, "isoformat"):
        return value.isoformat()[:10]
    return str(value)[:10]


def _source_fields(article: Article) -> dict[str, Any]:
    site_code = article.site.site_code if article.site is not None else ""
    target_code = article.target.target_code if article.target is not None else ""
    return {
        "source_system": "GovCrawler",
        "source_article_id": str(article.id),
        "source_site_code": site_code,
        "source_target_code": target_code,
        "source_url": article.url,
        "channel_name": article.channel_name or "",
        "channel_path": article.channel_path or "",
        "source_metadata": {
            "url_hash": article.url_hash,
            "native_post_id": article.native_post_id,
            "source_raw": article.source_raw,
            "index_no": article.index_no,
            "open_category": article.open_category,
            "content_category": article.content_category,
            "content_subcategory": article.content_subcategory,
            "effective_date": _date_str(getattr(article, "effective_date", None)),
            "is_effective": getattr(article, "is_effective", None),
            "expiry_date": _date_str(getattr(article, "expiry_date", None)),
            "metadata_json": article.metadata_json or {},
        },
    }


def _attachment_related_docs(article: Article) -> list[dict[str, str]]:
    related: list[dict[str, str]] = []
    for attachment in getattr(article, "attachments", []) or []:
        attachment_id = getattr(attachment, "id", None)
        if attachment_id is None:
            continue
        title = (
            getattr(attachment, "file_name", None)
            or getattr(attachment, "original_filename", None)
            or f"attachment {attachment_id}"
        )
        related.append({
            "doc_id": attachment_doc_id(article.id, attachment_id),
            "title": title,
            "relation_type": "\u9644\u4ef6",
        })
    return related


def _article_related_doc(article: Article) -> dict[str, str]:
    title = article.title or f"\u6587\u7ae0 {article.id}"
    return {
        "doc_id": article_doc_id(article.id),
        "title": title,
        "relation_type": "\u6b63\u6587",
    }


def build_article_metadata(article: Article, *, filename: str) -> dict[str, Any]:
    title = article.title or Path(filename).stem
    return {
        "doc_id": article_doc_id(article.id),
        "title": title,
        "original_filename": filename,
        "doc_number": article.doc_no or "",
        "issuing_org": article.publisher or article.source_raw or "",
        "doc_type": "网页正文",
        "subject_words": split_subject_words(article.topic_words),
        "publish_date": _date_str(article.publish_date or article.publish_time),
        "effective_date": _date_str(getattr(article, "effective_date", None)),
        "is_effective": getattr(article, "is_effective", None),
        "expiry_date": _date_str(getattr(article, "expiry_date", None)),
        "knowledge_category": KNOWLEDGE_CATEGORY,
        "acl_ids": [],
        "related_docs": _attachment_related_docs(article),
        **_source_fields(article),
    }


def build_attachment_metadata(
    article: Article,
    attachment: Attachment,
    *,
    filename: str,
) -> dict[str, Any]:
    base_title = article.title or f"文章 {article.id}"
    title = f"{base_title} - 附件：{attachment.file_name or filename}"
    source_fields = _source_fields(article)
    source_fields["source_attachment_id"] = str(attachment.id)
    attachment_source_url = getattr(attachment, "source_url", None)
    if attachment_source_url:
        source_fields["source_url"] = attachment_source_url
    source_fields["source_metadata"] = {
        **source_fields["source_metadata"],
        "article_source_url": article.url,
        "attachment_source_url": attachment_source_url or "",
        "attachment_file_hash": attachment.file_hash,
        "attachment_file_ext": attachment.file_ext,
        "attachment_size_bytes": attachment.size_bytes,
    }
    return {
        "doc_id": attachment_doc_id(article.id, attachment.id),
        "title": title,
        "original_filename": filename,
        "doc_number": article.doc_no or "",
        "issuing_org": article.publisher or article.source_raw or "",
        "doc_type": "网页附件",
        "subject_words": split_subject_words(article.topic_words),
        "publish_date": _date_str(article.publish_date or article.publish_time),
        "effective_date": _date_str(getattr(article, "effective_date", None)),
        "is_effective": getattr(article, "is_effective", None),
        "expiry_date": _date_str(getattr(article, "expiry_date", None)),
        "knowledge_category": KNOWLEDGE_CATEGORY,
        "acl_ids": [],
        "related_docs": [_article_related_doc(article)],
        **source_fields,
    }


class RagExportError(RuntimeError):
    pass


class RagIngestClient:
    def __init__(self, settings: Settings | None = None):
        self._settings = settings or get_settings()
        self._client = httpx.Client(timeout=self._settings.rag_export_timeout_s)

    def close(self) -> None:
        self._client.close()

    def _headers(self) -> dict[str, str]:
        if not self._settings.rag_ingest_token:
            return {}
        return {"Authorization": f"Bearer {self._settings.rag_ingest_token}"}

    def _status_url(self, task_id: str) -> str:
        if self._settings.rag_status_url:
            return self._settings.rag_status_url.format(task_id=task_id)
        ingest_url = self._settings.rag_ingest_url
        if "/webhook/document" in ingest_url:
            return ingest_url.replace("/webhook/document", f"/webhook/status/{task_id}")
        return f"{ingest_url.rstrip('/')}/status/{task_id}"

    def ingest_file(self, file_path: Path, metadata: dict[str, Any]) -> dict[str, Any]:
        filename = metadata.get("original_filename") or file_path.name
        media_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
        with file_path.open("rb") as fh:
            response = self._client.post(
                self._settings.rag_ingest_url,
                headers=self._headers(),
                data={"metadata": json.dumps(metadata, ensure_ascii=False)},
                files={"file": (filename, fh, media_type)},
            )
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise RagExportError(
                f"zm-rag ingest failed ({response.status_code}): {response.text[:500]}"
            ) from exc
        try:
            return response.json()
        except ValueError as exc:
            raise RagExportError("zm-rag ingest returned non-JSON response") from exc

    def get_task_status(self, task_id: str) -> dict[str, Any]:
        response = self._client.get(
            self._status_url(task_id),
            headers=self._headers(),
        )
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise RagExportError(
                f"zm-rag status check failed ({response.status_code}): {response.text[:500]}"
            ) from exc
        try:
            return response.json()
        except ValueError as exc:
            raise RagExportError("zm-rag status endpoint returned non-JSON response") from exc

    def wait_for_task(self, task_id: str) -> dict[str, Any]:
        deadline = time.monotonic() + self._settings.rag_export_poll_timeout_s
        terminal_ok = {"COMPLETED", "PARTIAL_FAILED"}
        terminal_failed = {"FAILED", "REVOKED"}

        while True:
            payload = self.get_task_status(task_id)
            status = str(payload.get("status", "")).upper()
            if status in terminal_ok:
                return payload
            if status in terminal_failed:
                raise RagExportError(
                    f"zm-rag ingest task {task_id} ended with {status}: {payload.get('error') or ''}"
                )
            if time.monotonic() >= deadline:
                raise RagExportError(
                    f"zm-rag ingest task {task_id} did not complete within "
                    f"{self._settings.rag_export_poll_timeout_s:.0f}s; last status={status or 'unknown'}"
                )
            time.sleep(self._settings.rag_export_poll_interval_s)


class RagExporter:
    def __init__(
        self,
        *,
        settings: Settings | None = None,
        client: RagIngestClient | None = None,
    ):
        self._settings = settings or get_settings()
        self._client = client or RagIngestClient(self._settings)
        self._owns_client = client is None

    def close(self) -> None:
        if self._owns_client:
            self._client.close()

    def export_pending(
        self,
        *,
        limit: int | None = None,
        article_id: int | None = None,
        target_code: str | None = None,
        dry_run: bool = False,
        source: str = "manual",
    ) -> BatchExportResult:
        SessionMaker = get_sessionmaker()
        with SessionMaker() as session:
            if not (self._settings.rag_ingest_url or "").strip():
                if dry_run:
                    articles = self._load_articles(
                        session,
                        limit=limit,
                        article_id=article_id,
                        target_code=target_code,
                    )
                    return BatchExportResult(
                        total=len(articles),
                        exported=0,
                        failed=0,
                        dry_run=True,
                        items=[
                            ArticleExportResult(
                                article_id=article.id,
                                title=article.title or "",
                                status="dry_run",
                            )
                            for article in articles
                        ],
                    )
                articles = self._claim_articles(
                    session,
                    limit=limit,
                    article_id=article_id,
                    target_code=target_code,
                )
                items = [
                    self._record_disabled_export(session, article, source=source)
                    for article in articles
                ]
                session.commit()
                return BatchExportResult(
                    total=len(items),
                    exported=0,
                    failed=len(items),
                    dry_run=False,
                    items=items,
                )
            if dry_run:
                articles = self._load_articles(
                    session,
                    limit=limit,
                    article_id=article_id,
                    target_code=target_code,
                )
            else:
                articles = self._claim_articles(
                    session,
                    limit=limit,
                    article_id=article_id,
                    target_code=target_code,
                )
            items: list[ArticleExportResult] = []
            exported = 0
            failed = 0
            for article in articles:
                result = self.export_article(
                    session,
                    article,
                    dry_run=dry_run,
                    source=source,
                )
                items.append(result)
                if result.status == "exported":
                    exported += 1
                elif result.status == "failed":
                    failed += 1
            return BatchExportResult(
                total=len(articles),
                exported=exported,
                failed=failed,
                dry_run=dry_run,
                items=items,
            )

    def export_article(
        self,
        session: Session,
        article: Article,
        *,
        dry_run: bool = False,
        source: str = "manual",
    ) -> ArticleExportResult:
        result = ArticleExportResult(
            article_id=article.id,
            title=article.title or "",
            status="dry_run" if dry_run else "exported",
        )
        started_at = datetime.utcnow()
        started_monotonic = time.monotonic()
        log_id: int | None = None
        if not dry_run:
            push_log = ArticleRagPushLog(
                article_id=article.id,
                source=source if source in {"manual", "auto"} else "manual",
                status="running",
                file_count=0,
                started_at=started_at,
                rag_ingest_url=self._settings.rag_ingest_url,
            )
            session.add(push_log)
            session.commit()
            log_id = push_log.id
        try:
            body_path = self._resolve_data_file(article.text_path)
            body_filename = body_path.name
            body_meta = build_article_metadata(article, filename=body_filename)
            result.documents.append(
                self._export_document(
                    file_path=body_path,
                    metadata=body_meta,
                    kind="article",
                    dry_run=dry_run,
                )
            )

            for attachment in article.attachments:
                attachment_path = self._resolve_data_file(attachment.file_path)
                attachment_filename = attachment.file_name or attachment_path.name
                attachment_meta = build_attachment_metadata(
                    article,
                    attachment,
                    filename=attachment_filename,
                )
                result.documents.append(
                    self._export_document(
                        file_path=attachment_path,
                        metadata=attachment_meta,
                        kind="attachment",
                        dry_run=dry_run,
                    )
                )

            document_error = self._document_failure_error(result.documents)
            if document_error:
                raise RagExportError(document_error)

            if not dry_run:
                now = datetime.utcnow()
                duration_ms = int((time.monotonic() - started_monotonic) * 1000)
                task_ids = [doc.task_id for doc in result.documents if doc.task_id]
                article.exported_to_rag_at = now
                article.rag_export_status = "completed"
                article.rag_export_finished_at = now
                article.rag_export_error = None
                article.rag_export_task_ids = task_ids
                if log_id is not None:
                    push_log = session.get(ArticleRagPushLog, log_id)
                    if push_log is not None:
                        push_log.status = "completed"
                        push_log.file_count = len(result.documents)
                        push_log.duration_ms = duration_ms
                        push_log.finished_at = now
                        push_log.error_msg = None
                        push_log.task_ids = task_ids
                session.commit()
        except Exception as exc:
            session.rollback()
            if not dry_run:
                now = datetime.utcnow()
                duration_ms = int((time.monotonic() - started_monotonic) * 1000)
                task_ids = [doc.task_id for doc in result.documents if doc.task_id]
                failed_article = session.get(Article, article.id)
                if failed_article is not None:
                    failed_article.rag_export_status = "failed"
                    failed_article.rag_export_finished_at = now
                    failed_article.rag_export_error = str(exc)[:4000]
                    failed_article.rag_export_task_ids = task_ids
                if log_id is not None:
                    push_log = session.get(ArticleRagPushLog, log_id)
                    if push_log is not None:
                        push_log.status = "failed"
                        push_log.file_count = len(result.documents)
                        push_log.duration_ms = duration_ms
                        push_log.finished_at = now
                        push_log.error_msg = str(exc)[:4000]
                        push_log.task_ids = task_ids
                session.commit()
            result.status = "failed"
            result.error = str(exc)
        return result

    def _record_disabled_export(
        self,
        session: Session,
        article: Article,
        *,
        source: str,
    ) -> ArticleExportResult:
        now = datetime.utcnow()
        error = "RAG_INGEST_URL is empty; skip RAG export"
        normalized_source = source if source in {"manual", "auto"} else "manual"
        file_count = 1 + len(getattr(article, "attachments", []) or [])
        article.rag_export_status = "failed"
        article.rag_export_started_at = now
        article.rag_export_finished_at = now
        article.rag_export_error = error
        article.rag_export_task_ids = []
        session.add(
            ArticleRagPushLog(
                article_id=article.id,
                source=normalized_source,
                status="failed",
                file_count=file_count,
                duration_ms=0,
                started_at=now,
                finished_at=now,
                error_msg=error,
                task_ids=[],
                rag_ingest_url=self._settings.rag_ingest_url,
            )
        )
        return ArticleExportResult(
            article_id=article.id,
            title=article.title or "",
            status="failed",
            error=error,
        )

    @staticmethod
    def _document_failure_error(documents: list[ExportedDocument]) -> str | None:
        failures: list[str] = []
        for doc in documents:
            status = (doc.status or "").upper()
            if status in {"FAILED", "PARTIAL_FAILED", "REVOKED"} or doc.error:
                detail = doc.error or f"status={doc.status or 'unknown'}"
                failures.append(f"{doc.kind} {doc.doc_id} {detail}")
        if not failures:
            return None
        return "zm-rag ingest returned unsuccessful document status: " + "; ".join(failures)

    def _export_document(
        self,
        *,
        file_path: Path,
        metadata: dict[str, Any],
        kind: str,
        dry_run: bool,
    ) -> ExportedDocument:
        if dry_run:
            return ExportedDocument(
                doc_id=metadata["doc_id"],
                kind=kind,
                title=metadata.get("title", ""),
                filename=metadata.get("original_filename", file_path.name),
                status="dry_run",
            )
        payload = self._client.ingest_file(file_path, metadata)
        task_id = payload.get("task_id")
        if self._settings.rag_export_wait_completion:
            if not task_id:
                raise RagExportError("zm-rag ingest response did not include task_id")
            payload = self._client.wait_for_task(task_id)
        return ExportedDocument(
            doc_id=metadata["doc_id"],
            kind=kind,
            title=metadata.get("title", ""),
            filename=metadata.get("original_filename", file_path.name),
            status=payload.get("status", "queued"),
            task_id=task_id or payload.get("task_id"),
            error=payload.get("error"),
        )

    @staticmethod
    def _claimable_filter(stale_before: datetime | None = None):
        conditions = [
            Article.rag_export_status.is_(None),
            Article.rag_export_status.in_(("pending", "failed")),
        ]
        if stale_before is not None:
            conditions.append(
                and_(
                    Article.rag_export_status == "running",
                    Article.rag_export_started_at.is_not(None),
                    Article.rag_export_started_at < stale_before,
                )
            )
        return or_(*conditions)

    def _claim_articles(
        self,
        session: Session,
        *,
        limit: int | None,
        article_id: int | None,
        target_code: str | None,
    ) -> list[Article]:
        now = datetime.utcnow()
        stale_before = now - timedelta(seconds=self._settings.rag_export_running_stale_s)
        base_query = (
            select(Article.id)
            .where(
                Article.status == "ready",
                Article.exported_to_rag_at.is_(None),
                self._claimable_filter(stale_before),
            )
            .order_by(Article.fetched_at.asc(), Article.id.asc())
            .limit(limit or self._settings.rag_export_batch_size)
        )
        if target_code is not None:
            base_query = (
                base_query
                .join(CrawlTarget, CrawlTarget.id == Article.target_id)
                .where(CrawlTarget.target_code == target_code)
            )
        if article_id is not None:
            base_query = (
                select(Article.id)
                .where(
                    Article.id == article_id,
                    Article.status == "ready",
                    Article.exported_to_rag_at.is_(None),
                    self._claimable_filter(stale_before),
                )
                .limit(1)
            )
            if target_code is not None:
                base_query = (
                    base_query
                    .join(CrawlTarget, CrawlTarget.id == Article.target_id)
                    .where(CrawlTarget.target_code == target_code)
                )

        candidate_ids = list(session.execute(base_query).scalars().all())
        if not candidate_ids:
            return []

        claimed_ids: list[int] = []
        for candidate_id in candidate_ids:
            result = session.execute(
                update(Article)
                .where(
                    Article.id == candidate_id,
                    Article.status == "ready",
                    Article.exported_to_rag_at.is_(None),
                    self._claimable_filter(stale_before),
                )
                .values(
                    rag_export_status="running",
                    rag_export_started_at=now,
                    rag_export_finished_at=None,
                    rag_export_error=None,
                    rag_export_task_ids=[],
                )
            )
            if result.rowcount:
                claimed_ids.append(candidate_id)
        session.commit()

        if not claimed_ids:
            return []
        stale_before = (
            datetime.utcnow()
            - timedelta(seconds=self._settings.rag_export_running_stale_s)
        )
        query = (
            select(Article)
            .options(
                selectinload(Article.attachments),
                selectinload(Article.site),
                selectinload(Article.target),
            )
            .where(Article.id.in_(claimed_ids))
            .order_by(Article.fetched_at.asc(), Article.id.asc())
        )
        return list(session.execute(query).scalars().all())

    def _load_articles(
        self,
        session: Session,
        *,
        limit: int | None,
        article_id: int | None,
        target_code: str | None,
    ) -> list[Article]:
        stale_before = (
            datetime.utcnow()
            - timedelta(seconds=self._settings.rag_export_running_stale_s)
        )
        query = (
            select(Article)
            .options(
                selectinload(Article.attachments),
                selectinload(Article.site),
                selectinload(Article.target),
            )
            .where(Article.status == "ready")
        )
        if target_code is not None:
            query = (
                query
                .join(CrawlTarget, CrawlTarget.id == Article.target_id)
                .where(CrawlTarget.target_code == target_code)
            )
        if article_id is not None:
            query = query.where(
                Article.id == article_id,
                Article.exported_to_rag_at.is_(None),
                self._claimable_filter(stale_before),
            )
        else:
            query = query.where(
                Article.exported_to_rag_at.is_(None),
                self._claimable_filter(stale_before),
            )
        query = query.order_by(Article.fetched_at.asc(), Article.id.asc())
        query = query.limit(limit or self._settings.rag_export_batch_size)
        return list(session.execute(query).scalars().all())

    def _resolve_data_file(self, rel_path: str | None) -> Path:
        if not rel_path:
            raise FileNotFoundError("source file path is empty")
        data_dir = Path(self._settings.data_dir)
        abs_path = to_os_path(data_dir, PurePosixPath(rel_path))
        try:
            abs_path.resolve().relative_to(data_dir.resolve())
        except Exception as exc:
            raise ValueError(f"invalid source file path: {rel_path}") from exc
        if not abs_path.exists():
            raise FileNotFoundError(f"source file missing: {abs_path}")
        return abs_path


def export_pending_to_rag(
    *,
    limit: int | None = None,
    article_id: int | None = None,
    target_code: str | None = None,
    dry_run: bool = False,
    source: str = "manual",
) -> BatchExportResult:
    exporter = RagExporter()
    try:
        return exporter.export_pending(
            limit=limit,
            article_id=article_id,
            target_code=target_code,
            dry_run=dry_run,
            source=source,
        )
    finally:
        exporter.close()
