"""
回填已入库文档的元数据（title, doc_type, issuing_org 等）。

对所有 title 为空的 content_hash 分组，取第一个 chunk 的 content 文本，
用 MetadataExtractor 提取元数据后，patch 到该 content_hash 的所有 chunks。

用法:
    cd backend
    python -m scripts.backfill_metadata
"""

from __future__ import annotations

import asyncio
import json
import sys

import requests

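# Make the `app` package importable when the script is run from the backend/ directory.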
sys.path.insert(0, ".")

from app.config import settings  # noqa: E402


ES_URL = settings.es_host
CHUNK_INDEX = settings.es_chunk_index


def get_empty_metadata_groups() -> list[dict]:
    """获取 title 为空的所有 content_hash，及其第一个 chunk 的 content。"""
    body = {
        "size": 0,
        "query": {
            "bool": {
                "should": [
                    {"term": {"title.keyword": ""}},
                    {"bool": {"must_not": {"exists": {"field": "title"}}}},
                ],
                "minimum_should_match": 1,
            }
        },
        "aggs": {
            "by_hash": {
                "terms": {"field": "content_hash", "size": 500},
                "aggs": {
                    "first_chunk": {
                        "top_hits": {
                            "size": 1,
                            "sort": [{"chunk_index": "asc"}],
                            "_source": ["content", "content_hash"],
                        }
                    }
                },
            }
        },
    }
    r = requests.post(f"{ES_URL}/{CHUNK_INDEX}/_search", json=body)
    r.raise_for_status()
    data = r.json()
    groups = []
    for bucket in data["aggregations"]["by_hash"]["buckets"]:
        hit = bucket["first_chunk"]["hits"]["hits"][0]["_source"]
        groups.append({
            "content_hash": hit["content_hash"],
            "content": hit.get("content", ""),
            "chunk_count": bucket["doc_count"],
        })
    return groups


async def extract_and_patch(groups: list[dict]) -> None:
    from app.core.metadata_extractor import MetadataExtractor
    from app.infrastructure.llm_client import LLMClient

    llm = LLMClient()
    extractor = MetadataExtractor(llm)

    patched = 0
    failed = 0

    for i, group in enumerate(groups):
        content_hash = group["content_hash"]
        content = group["content"]
        if not content.strip():
            print(f"[{i+1}/{len(groups)}] SKIP {content_hash[:12]}... (empty content)")
            failed += 1
            continue

        # Only the first chunk's content is used for extraction; combining the
        # first few chunks could give the extractor more context.
        try:
            meta = await extractor.extract(content)
        except Exception as e:
            print(f"[{i+1}/{len(groups)}] FAIL {content_hash[:12]}... extract error: {e}")
            failed += 1
            continue

        title = meta.get("title", "")
        if not title:
            print(f"[{i+1}/{len(groups)}] WARN {content_hash[:12]}... no title extracted")

        # Build update fields
        fields = {}
        for key in ["title", "doc_number", "issuing_org", "doc_type",
                     "publish_date", "signer", "subject_words",
                     "knowledge_category"]:
            val = meta.get(key)
            if val:
                fields[key] = val

        if not fields:
            print(f"[{i+1}/{len(groups)}] SKIP {content_hash[:12]}... no fields extracted")
            failed += 1
            continue

        # Patch all chunks
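        # Each extracted field becomes one Painless assignment; the script is
        # applied to every chunk sharing this content_hash via _update_by_query.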
        assignments = [f"ctx._source.{k} = params.fields.{k}" for k in fields]
        body = {
            "query": {"term": {"content_hash": content_hash}},
            "script": {
                "lang": "painless",
                "source": "; ".join(assignments),
                "params": {"fields": fields},
            },
        }
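        # refresh=true makes the updated chunks searchable immediately;
        # conflicts=proceed skips version conflicts instead of failing the request.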
        r = requests.post(
            f"{ES_URL}/{CHUNK_INDEX}/_update_by_query?refresh=true&conflicts=proceed",
            json=body,
        )
        r.raise_for_status()
        updated = r.json().get("updated", 0)
        patched += 1
        print(
            f"[{i+1}/{len(groups)}] OK {content_hash[:12]}... "
            f"title=\"{title[:30]}\" org=\"{meta.get('issuing_org','')}\" "
            f"type=\"{meta.get('doc_type','')}\" updated={updated} chunks"
        )

    # Note: only the chunk index is patched; a separate meta index is not updated here.
    print(f"\n--- Done: {patched} patched, {failed} failed/skipped ---")


def main():
    print("=== Backfill document metadata ===")
    groups = get_empty_metadata_groups()
    print(f"Found {len(groups)} content groups with empty metadata\n")
    if not groups:
        print("Nothing to do.")
        return
    asyncio.run(extract_and_patch(groups))


if __name__ == "__main__":
    main()
