"""Bulk knowledge-graph construction script.

Usage
-----
# Build graphs for every completed document (default: 2 workers):
python scripts/bulk_build_graph.py

# Only documents that don't already have graph data:
python scripts/bulk_build_graph.py --skip-existing

# Override concurrency and limit:
python scripts/bulk_build_graph.py --concurrency 4 --limit 500

# Target specific doc IDs:
python scripts/bulk_build_graph.py --doc-ids id1,id2,id3

# Dry-run (list docs but don't build):
python scripts/bulk_build_graph.py --dry-run
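
Configuration
-------------
Connection details and index names are read from ``app.config.settings``
(es_host, es_username / es_password, es_meta_index, es_chunk_index,
neo4j_database, graph_max_content_chars).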
"""

from __future__ import annotations

import argparse
import asyncio
import json
import sys
import time
from pathlib import Path
from typing import Any

# ── Bootstrap: make 'app' importable ──────────────────────────────────────────
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from app.config import settings
from app.core.document_processor import DocumentProcessor
from app.core.graph_builder import GraphBuilder
from app.infrastructure.llm_client import LLMClient
from app.infrastructure.neo4j_client import Neo4jClient
from app.utils.logger import configure_logging, get_logger

configure_logging(debug=False)
logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


async def scroll_all_docs(
    es: Any,
    *,
    size: int = 200,
    limit: int | None = None,
) -> list[dict[str, Any]]:
    """Scroll through gov_doc_meta and return all doc metadata records."""
    docs: list[dict[str, Any]] = []
    resp = await es.search(
        index=settings.es_meta_index,
        body={
            "query": {"term": {"status": "completed"}},
            "sort": [{"created_at": {"order": "asc"}}],
            "size": size,
            "_source": True,
        },
        scroll="5m",
    )
    scroll_id = resp.get("_scroll_id")

    while True:
        hits = resp["hits"]["hits"]
        if not hits:
            break
        for hit in hits:
            docs.append(hit["_source"])
            if limit and len(docs) >= limit:
                if scroll_id:
                    await es.clear_scroll(body={"scroll_id": [scroll_id]})
                return docs

        if not scroll_id:
            break
        resp = await es.scroll(body={"scroll_id": scroll_id}, scroll="5m")
        # the scroll id can change between pages; keep the latest one
        scroll_id = resp.get("_scroll_id", scroll_id)

    if scroll_id:
        try:
            await es.clear_scroll(body={"scroll_id": [scroll_id]})
        except Exception:
            pass

    return docs

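# For reference, the metadata fields this script consumes (shape inferred from
# the code in this file; the index likely stores more):
#
#   {
#       "doc_id": "...",                 # required; keys the Document node
#       "status": "completed",           # only completed docs are scrolled
#       "created_at": "...",             # scroll sort key
#       "file_path": "/path/to/source",  # optional; preferred text source
#       "content_hash": "...",           # fallback key into the chunk index
#       "title": "...",                  # shown in --dry-run listings
#       "document_scene_type": "...",    # or knowledge_category_code
#   }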

async def has_graph(neo4j: Neo4jClient, doc_id: str) -> bool:
    """Return True if the Document node already has at least one relationship."""
    async with neo4j.driver.session(database=settings.neo4j_database) as session:
        result = await session.run(
            "MATCH (d:Document {doc_id: $doc_id})-[r]-() "
            "RETURN count(r) AS cnt LIMIT 1",
            doc_id=doc_id,
        )
        record = await result.single()
        return bool(record and record["cnt"] > 0)

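# Note: count(r) expands every relationship on the node before returning. If
# that ever becomes expensive, an equivalent early-exit probe would be
# (a sketch, not used here):
#
#   MATCH (d:Document {doc_id: $doc_id})--()
#   RETURN true AS has_rel
#   LIMIT 1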

async def get_text_for_doc(es: Any, doc_id: str, file_path: str | None) -> str:
    """Get document full text: try file first, fall back to ES chunk concat."""
    # Option A: re-extract from file (extract_text is synchronous, so run it
    # in a worker thread to keep the event loop responsive)
    if file_path and Path(file_path).is_file():
        try:
            processor = DocumentProcessor()
            result = await asyncio.to_thread(processor.extract_text, file_path)
            text = result.get("full_text", "")
            if text.strip():
                return text
        except Exception as exc:
            logger.warning("file_extract_failed", doc_id=doc_id, error=str(exc))

    # Option B: concatenate chunks from Elasticsearch (by content_hash)
    content_hash = ""
    try:
        meta_resp = await es.get(
            index=settings.es_meta_index, id=doc_id, _source=["content_hash"]
        )
        # opensearch-py returns either a plain dict or a response object with
        # a .body attribute, depending on client version; normalise here
        raw = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
        content_hash = raw.get("_source", {}).get("content_hash", "")
    except Exception:
        pass  # meta record missing or malformed; warned about just below

    if not content_hash:
        logger.warning("no_content_hash", doc_id=doc_id)
        return ""

    resp = await es.search(
        index=settings.es_chunk_index,
        body={
            "query": {"term": {"content_hash": content_hash}},
            "sort": [{"chunk_index": {"order": "asc"}}],
            "size": 2000,
            "_source": ["content"],
        },
    )
    chunks = [hit["_source"].get("content", "") for hit in resp["hits"]["hits"]]
    return "\n".join(chunks)

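# The chunk query above caps at 2000 hits, so an unusually long document would
# be silently truncated. Should that ever matter, a search_after paging sketch
# (assumes chunk_index values are unique within a content_hash):
#
#   chunks, after = [], None
#   while True:
#       body = {
#           "query": {"term": {"content_hash": content_hash}},
#           "sort": [{"chunk_index": {"order": "asc"}}],
#           "size": 500,
#           "_source": ["content"],
#       }
#       if after:
#           body["search_after"] = after
#       resp = await es.search(index=settings.es_chunk_index, body=body)
#       hits = resp["hits"]["hits"]
#       if not hits:
#           break
#       chunks.extend(h["_source"].get("content", "") for h in hits)
#       after = hits[-1]["sort"]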

# ---------------------------------------------------------------------------
# Worker coroutine
# ---------------------------------------------------------------------------


async def process_doc(
    doc: dict[str, Any],
    builder: GraphBuilder,
    es: Any,
    neo4j: Neo4jClient,
    *,
    skip_existing: bool,
    stats: dict[str, int],
) -> None:
    doc_id: str = doc.get("doc_id", "")
    if not doc_id:
        stats["skipped"] += 1
        return

    try:
        # Skip docs that already have graph data
        if skip_existing and await has_graph(neo4j, doc_id):
            logger.info("graph_exists_skip", doc_id=doc_id)
            stats["skipped"] += 1
            return

        file_path: str | None = doc.get("file_path")
        content = await get_text_for_doc(es, doc_id, file_path)

        if not content.strip():
            logger.warning("no_content_skip", doc_id=doc_id)
            stats["skipped"] += 1
            return

        scene_type = (
            doc.get("document_scene_type")
            or doc.get("knowledge_category_code")
            or ""
        )
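        # build_graph's result shape as read below (fields inferred from this
        # script; values illustrative):
        #   {"status": "completed", "entity_count": 12, "relation_count": 30}
        # or a non-"completed" status plus an "error" message on failure.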
        result = await builder.build_graph(doc_id, doc, content, scene_type=scene_type)

        if result.get("status") == "completed":
            stats["success"] += 1
            logger.info(
                "graph_built",
                doc_id=doc_id,
                entities=result.get("entity_count", 0),
                relations=result.get("relation_count", 0),
            )
        else:
            stats["failed"] += 1
            logger.error("graph_failed", doc_id=doc_id, error=result.get("error"))

    except Exception as exc:
        stats["failed"] += 1
        logger.error("graph_exception", doc_id=doc_id, error=str(exc))


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


async def main(args: argparse.Namespace) -> None:
    from opensearchpy import AsyncOpenSearch

    # ── Initialise clients ─────────────────────────────────────────────
    es_kwargs: dict[str, Any] = {
        "hosts": [settings.es_host],
        "timeout": 60,
    }
    if settings.es_username:
        es_kwargs["http_auth"] = (settings.es_username, settings.es_password)

    es = AsyncOpenSearch(**es_kwargs)
    llm_client = LLMClient()
    neo4j_client = Neo4jClient.from_settings()
    builder = GraphBuilder(
        llm_client,
        neo4j_client,
        max_content_chars=settings.graph_max_content_chars,
    )

    try:
        # ── Resolve doc list ───────────────────────────────────────────
        if args.doc_ids:
            doc_ids_set = {d.strip() for d in args.doc_ids.split(",") if d.strip()}
            all_docs: list[dict[str, Any]] = []
            for doc_id in doc_ids_set:
                try:
                    resp = await es.get(index=settings.es_meta_index, id=doc_id)
                    all_docs.append(resp["_source"])
                except Exception:
                    logger.warning("doc_not_found", doc_id=doc_id)
        else:
            logger.info("scrolling_all_docs")
            all_docs = await scroll_all_docs(es, limit=args.limit)

        total = len(all_docs)
        logger.info("bulk_graph_start", total=total)
        print(f"\nTotal documents: {total}")

        if args.dry_run:
            print("Dry-run mode -- listing doc IDs only:")
            for doc in all_docs[:50]:
                print(f"  {doc.get('doc_id', '?')}  {doc.get('title', '')[:60]}")
            if total > 50:
                print(f"  ... and {total - 50} more")
            return

        # ── Process with bounded concurrency ───────────────────────────
        # Shared tally; no lock needed: asyncio runs every worker on one
        # thread and each increment completes without an intervening await.
        stats: dict[str, int] = {"success": 0, "failed": 0, "skipped": 0}
        sem = asyncio.Semaphore(args.concurrency)
        start_time = time.time()

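        # Design note: the semaphore bounds in-flight builds while every task
        # is scheduled up front; gathering in fixed-size batches would instead
        # idle workers at each batch boundary.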
        async def guarded(doc: dict[str, Any]) -> None:
            async with sem:
                await process_doc(
                    doc,
                    builder,
                    es,
                    neo4j_client,
                    skip_existing=args.skip_existing,
                    stats=stats,
                )
                done = stats["success"] + stats["failed"] + stats["skipped"]
                if done % 10 == 0 or done == total:
                    elapsed = time.time() - start_time
                    rate = done / elapsed if elapsed > 0 else 0
                    print(
                        f"\r  [OK] {stats['success']} done  "
                        f"[FAIL] {stats['failed']} failed  "
                        f"[SKIP] {stats['skipped']} skipped  "
                        f"({rate:.1f} docs/s)",
                        end="",
                        flush=True,
                    )

        await asyncio.gather(*[guarded(doc) for doc in all_docs])

        elapsed = time.time() - start_time
        print(f"\n\nCompleted in {elapsed:.1f}s")
        print(json.dumps(stats, ensure_ascii=False, indent=2))

    finally:
        await es.close()
        await llm_client.close()
        await neo4j_client.close()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Bulk build knowledge graphs from ES documents")
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        default=False,
        help="Skip documents that already have graph nodes",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=2,
        help="Number of parallel LLM calls (default: 2)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Maximum number of documents to process",
    )
    parser.add_argument(
        "--doc-ids",
        type=str,
        default=None,
        help="Comma-separated list of specific doc_ids to process",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="List documents without building graphs",
    )
    return parser.parse_args()


if __name__ == "__main__":
    asyncio.run(main(parse_args()))
