"""Celery tasks for notebook source ingestion.

Asynchronous Celery tasks for notebook source ingestion. Uses the existing
IngestPipeline to process uploaded/pasted source files and sets an
NB_{notebook_id} ACL entry for knowledge isolation.
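
Enqueueing looks roughly like this (illustrative values; the actual call
site may pass different paths and metadata):

    notebook_ingest_source_task.delay(
        source_id="src-1",
        notebook_id="nb-1",
        user_id="u-1",
        doc_id="doc-1",
        file_path="uploads/nb-1/report.pdf",
        title="Report",
    )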
"""

from __future__ import annotations

import asyncio
import hashlib
from typing import Any

from celery.exceptions import MaxRetriesExceededError, Retry
from celery.utils.log import get_task_logger

from app.tasks.celery_app import celery_app

logger = get_task_logger(__name__)


def _run_async(coro):
    """Run *coro* on a fresh event loop (Celery workers are synchronous)."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)  # some libraries still call asyncio.get_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        asyncio.set_event_loop(None)
        loop.close()


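# acks_late re-delivers the message if the worker dies mid-task, and
# track_started exposes a STARTED state to pollers; with max_retries=2 each
# source is attempted at most three times overall.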
@celery_app.task(
    name="tasks.notebook_ingest_source",
    bind=True,
    max_retries=2,
    default_retry_delay=30,
    acks_late=True,
    track_started=True,
)
def notebook_ingest_source_task(
    self,
    source_id: str,
    notebook_id: str,
    user_id: str,
    doc_id: str,
    file_path: str,
    title: str = "",
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Celery task: ingest a notebook source file into ES with NB_ ACL.

    Args:
        source_id: Notebook source record ID.
        notebook_id: Parent notebook ID.
        user_id: Owner user ID.
        doc_id: ES document ID for the ingested content.
        file_path: Absolute path to the file to ingest.
        title: Document title.
        metadata: Additional metadata dict.

    Returns:
        dict with source_id and status ("completed" or "failed"), plus
        error when the ingest failed.
    """
    metadata = metadata or {}
    attempt = self.request.retries + 1
    logger.info(
        "Starting notebook source ingest source_id=%s, notebook_id=%s, attempt=%s",
        source_id, notebook_id, attempt,
    )

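    # Publish a custom PROCESSING state so callers polling the result
    # backend can see progress before the task completes.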
    self.update_state(
        state="PROCESSING",
        meta={"source_id": source_id, "notebook_id": notebook_id, "step": "initializing"},
    )

    async def _run():
        from pathlib import Path

        from app.config import settings
        from app.core.ingest_pipeline import create_pipeline
        from app.infrastructure.es_client import ESClient
        from app.infrastructure.mysql_client import MySQLClient
        from app.infrastructure.notebook_store import NotebookStore
        from app.infrastructure.redis_client import RedisClient
        from opensearchpy import AsyncOpenSearch

        _redis_client = None
        _mysql_client = None
        _es_client = None
        pipeline = None

        try:
            # Create clients
            _redis_client = RedisClient.from_settings()

            es_kwargs: dict[str, Any] = {"hosts": [settings.es_host], "timeout": 60}
            if settings.es_username:
                es_kwargs["http_auth"] = (settings.es_username, settings.es_password)
            _es_client = ESClient(AsyncOpenSearch(**es_kwargs), redis_client=_redis_client)

            if settings.mysql_enabled and settings.mysql_host and settings.mysql_database:
                try:
                    _mysql_client = await MySQLClient.from_settings()
                except Exception as exc:
                    logger.warning("notebook_task_mysql_init_failed", error=str(exc))
                    raise RuntimeError(f"MySQL 初始化失败: {exc}")

            if _mysql_client is None:
                raise RuntimeError("MySQL is required for notebook source tracking")

            nb_store = NotebookStore(_mysql_client)

            # Update source status to processing
            await nb_store.update_source(
                source_id, notebook_id, user_id,
                ingest_status="processing", ingest_task_id=self.request.id,
            )

            # Set up metadata with NB_ ACL + user's own token for preview access
            ingest_metadata = {
                **metadata,
                "title": title,
                "acl_ids": [f"NB_{notebook_id}", user_id],
                "knowledge_category": "notebook",
                "notebook_id": notebook_id,
            }

            # Resolve the file path: relative paths are taken under file_storage_path
            p = Path(file_path)
            resolved = str(p) if p.is_absolute() else str(settings.file_storage_path / p)

            # Pre-check: if content_hash exists in meta but chunks are gone,
            # clear stale meta BEFORE running pipeline so dedup won't trigger
            _pre_hash = hashlib.sha256()
            with open(resolved, "rb") as _hf:
                for _chunk in iter(lambda: _hf.read(8192), b""):
                    _pre_hash.update(_chunk)
            _pre_content_hash = _pre_hash.hexdigest()

            _pre_meta = await _es_client.raw.search(
                index=settings.es_meta_index,
                body={"size": 1, "query": {"bool": {"must": [
                    {"term": {"content_hash": _pre_content_hash}},
                    {"term": {"status": "completed"}},
                ]}}},
            )
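            # hits.total is a bare int on older ES and {"value": n} on
            # ES7+/OpenSearch; normalize both shapes.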
            _pre_meta_raw = _pre_meta if isinstance(_pre_meta, dict) else _pre_meta.body
            _has_existing_meta = _pre_meta_raw.get("hits", {}).get("total", {})
            if isinstance(_has_existing_meta, dict):
                _has_existing_meta = _has_existing_meta.get("value", 0)

            if _has_existing_meta > 0:
                _chunk_check = await _es_client.raw.search(
                    index=settings.es_chunk_index,
                    body={"size": 0, "query": {"term": {"content_hash": _pre_content_hash}}},
                )
                _chunk_check_raw = _chunk_check if isinstance(_chunk_check, dict) else _chunk_check.body
                _chunk_count = _chunk_check_raw.get("hits", {}).get("total", {})
                if isinstance(_chunk_count, dict):
                    _chunk_count = _chunk_count.get("value", 0)

                if _chunk_count == 0:
                    logger.info(
                        "notebook_stale_dedup: meta exists but chunks=0, clearing meta. hash=%s",
                        _pre_content_hash[:12],
                    )
                    await _es_client.raw.delete_by_query(
                        index=settings.es_meta_index,
                        body={"query": {"term": {"content_hash": _pre_content_hash}}},
                        conflicts="proceed",
                        refresh=True,
                    )

            # Run ingest pipeline
            pipeline = create_pipeline(redis_client=_redis_client, mysql_client=_mysql_client)
            result = await pipeline.ingest_document(
                doc_id=doc_id,
                file_path=resolved,
                metadata=ingest_metadata,
            )

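            # "partial_failed" is accepted: the pipeline presumably ingested at
            # least some chunks, so the source remains usable.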
            status = result.get("status", "failed")
            if status not in ("completed", "partial_failed"):
                error_msg = result.get("error", "Ingest failed")
                await nb_store.update_source(
                    source_id, notebook_id, user_id,
                    ingest_status="failed", ingest_error=error_msg,
                )
                return {"source_id": source_id, "status": "failed", "error": error_msg}

            content_hash = result.get("content_hash")
            if content_hash:
                await nb_store.update_source(
                    source_id, notebook_id, user_id,
                    ingest_status="completed",
                    content_hash=content_hash,
                    doc_id=doc_id,
                )

                # Ensure chunks carry this doc_id, the NB_ ACL, and the user token
                nb_acl = f"NB_{notebook_id}"
                try:
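                    # update_by_query stamps every chunk sharing this content
                    # hash with the doc_id and ACL entries; conflicts="proceed"
                    # skips version conflicts from concurrent writers.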
                    await _es_client.raw.update_by_query(
                        index=settings.es_chunk_index,
                        body={
                            "query": {"term": {"content_hash": content_hash}},
                            "script": {
                                "lang": "painless",
                                "source": (
                                    "if (ctx._source.doc_ids == null) { ctx._source.doc_ids = new ArrayList(); }"
                                    "if (!ctx._source.doc_ids.contains(params.did)) { ctx._source.doc_ids.add(params.did); }"
                                    "if (ctx._source.acl_ids == null) { ctx._source.acl_ids = new ArrayList(); }"
                                    "if (!ctx._source.acl_ids.contains(params.acl)) { ctx._source.acl_ids.add(params.acl); }"
                                    "if (!ctx._source.acl_ids.contains(params.uid)) { ctx._source.acl_ids.add(params.uid); }"
                                ),
                                "params": {"did": doc_id, "acl": nb_acl, "uid": user_id},
                            },
                        },
                        conflicts="proceed",
                        refresh=True,
                    )
                    logger.info("notebook_chunks_ensured doc_id=%s content_hash=%s", doc_id, content_hash[:12])
                except Exception as exc:
                    logger.warning("notebook_chunks_update_failed doc_id=%s: %s", doc_id, exc)

                # Best-effort: generate source summary
                try:
                    from app.core.embedding import EmbeddingService
                    from app.core.graph_query_service import GraphQueryService
                    from app.core.notebook_service import NotebookService
                    from app.infrastructure.embedding_client import EmbeddingClient
                    from app.infrastructure.llm_client import LLMClient
                    from app.infrastructure.neo4j_client import Neo4jClient

                    _llm = LLMClient()
                    _embedding_client = EmbeddingClient()
                    _neo4j = Neo4jClient.from_settings()
                    _embedding_svc = EmbeddingService(_embedding_client)
                    _graph_svc = GraphQueryService(_neo4j)

                    nb_service = NotebookService(
                        store=nb_store,
                        es_client=_es_client,
                        embedding_service=_embedding_svc,
                        graph_service=_graph_svc,
                        llm_client=_llm,
                    )
                    await nb_service.generate_source_summary(source_id, notebook_id, user_id)

                    await _llm.close()
                    await _embedding_client.close()
                    await _neo4j.close()
                except Exception as exc:
                    logger.warning("notebook_source_summary_failed", error=str(exc))

                return {"source_id": source_id, "status": "completed"}
            else:
                error_msg = result.get("error", "Ingest failed")
                await nb_store.update_source(
                    source_id, notebook_id, user_id,
                    ingest_status="failed", ingest_error=error_msg,
                )
                return {"source_id": source_id, "status": "failed", "error": error_msg}

        except Exception as exc:
            # Update source status to failed
            if _mysql_client is not None:
                try:
                    nb_store = NotebookStore(_mysql_client)
                    await nb_store.update_source(
                        source_id, notebook_id, user_id,
                        ingest_status="failed", ingest_error=str(exc),
                    )
                except Exception as cleanup_exc:
                    logger.warning("notebook_source_status_update_failed", error=str(cleanup_exc))
            raise

        finally:
            if pipeline is not None:
                await pipeline.close()
            if _es_client is not None:
                await _es_client.close()
            if _redis_client is not None:
                await _redis_client.close()
            if _mysql_client is not None:
                await _mysql_client.close()

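    # Bridge the async body into the synchronous worker; retry failures with
    # a linear 30s/60s backoff until max_retries is exhausted.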
    try:
        result = _run_async(_run())
        if result.get("status") == "failed":
            error_msg = result.get("error", "Unknown error")
            logger.error("Notebook source ingest failed: %s", error_msg)
            if self.request.retries < self.max_retries:
                raise self.retry(
                    exc=RuntimeError(error_msg),
                    countdown=30 * (self.request.retries + 1),
                )
        return result

    except (Retry, MaxRetriesExceededError):
        raise
    except Exception as exc:
        logger.error("Unexpected error in notebook source ingest: %s", exc)
        if self.request.retries < self.max_retries:
            raise self.retry(exc=exc, countdown=30 * (self.request.retries + 1))
        return {"source_id": source_id, "status": "failed", "error": str(exc)}
