"""Big Lobster-compatible KB adapter endpoints.

This module exposes the narrow KB API contract already consumed by
``zm-ai-server``:

* GET  /lobster-kb/scopes
* POST /lobster-kb/search
* POST /lobster-kb/detail

The adapter keeps zm-rag as the authority for retrieval, graph evidence, and
ACL filtering while avoiding any direct OpenSearch/Neo4j coupling from
lobster.
"""

from __future__ import annotations

import re
from typing import Annotated, Any
from urllib.parse import urlparse

from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from opensearchpy.exceptions import NotFoundError
from pydantic import BaseModel, Field

from app.api.deps import UserContext
from app.config import settings
from app.core.embedding import EmbeddingService
from app.core.graph_query_service import GraphQueryService
from app.core.permission import PermissionService
from app.core.research_engine import ResearchEngine
from app.infrastructure.embedding_client import EmbeddingClient
from app.infrastructure.es_client import ESClient
from app.infrastructure.llm_client import LLMClient
from app.infrastructure.mysql_client import MySQLClient
from app.infrastructure.neo4j_client import Neo4jClient
from app.infrastructure.redis_client import RedisClient
from app.infrastructure.session_store import build_research_session_store
from app.utils.logger import get_logger

logger = get_logger(__name__)

router = APIRouter(prefix="/lobster-kb", tags=["lobster-kb"])

_DETAIL_CHUNK_SOURCE_FIELDS = [
    "content",
    "heading_hierarchy",
    "page_number",
    "page_numbers",
    "chunk_index",
]
_DETAIL_SCAN_BATCH_SIZE = 500
_HIGHLIGHT_EM_TAG_RE = re.compile(r"</?em\b[^>]*>", re.IGNORECASE)


class LobsterKbSearchRequest(BaseModel):
    query: str = Field(..., min_length=1, max_length=2000)
    maxResults: int = Field(10, ge=1, le=20)
    scopeIds: list[str] | None = None
    llm: bool = Field(False)
    escapeHtml: bool = Field(True)
    userId: str | None = None
    deptId: str | None = None
    orgId: str | None = None
    areaId: str | None = None
    roleIds: list[str] | None = None


class LobsterKbDetailRequest(BaseModel):
    docId: str = Field(..., min_length=1)
    offset: int = Field(0, ge=0)
    chunkCharOffset: int = Field(0, ge=0)
    limit: int = Field(8, ge=1, le=50)
    maxChars: int | None = Field(None, ge=1000, le=50_000)
    query: str | None = Field(None, max_length=500)
    sectionHint: str | None = Field(None, max_length=500)
    userId: str | None = None
    deptId: str | None = None
    orgId: str | None = None
    areaId: str | None = None
    roleIds: list[str] | None = None


def _ensure_prefix(value: str | None, prefix: str) -> str:
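    """Prefix ``value`` with ``prefix``, tolerating None, blanks, and existing prefixes.

    >>> _ensure_prefix("123", "U_")
    'U_123'
    >>> _ensure_prefix(" U_123 ", "U_")
    'U_123'
    >>> _ensure_prefix(None, "U_")
    ''
    """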
    if value is None:
        return ""
    value = value.strip()
    if not value:
        return ""
    return value if value.startswith(prefix) else f"{prefix}{value}"


def _map_lobster_org_to_office_id(org_id: str | None) -> str:
    """Map lobster orgId to zm-rag office_id.

    The integration owner confirmed that lobster ``orgId`` is the zm-rag
    ``office_id``. Keep the mapping in one helper so future org-tree changes do
    not leak into request handling.
    """
    return _ensure_prefix(org_id, "O_")


def _map_lobster_area_id(request: Request) -> str:
    """TODO: map lobster user/org context to zm-rag area_id.

    Current rollout intentionally leaves area mapping empty. Implement the
    project-specific hierarchy lookup here when area-level ACL needs to be
    enabled for lobster callers.
    """
    _ = request.headers.get("X-Area-Id")
    return ""


def _map_lobster_role_ids(request: Request) -> list[str]:
    """TODO: map lobster roles to zm-rag role_ids.

    Current rollout intentionally leaves role mapping empty. Implement the
    project-specific role translation here when role-level ACL needs to be
    enabled for lobster callers.
    """
    _ = request.headers.get("X-Role-Ids")
    return []


def _extract_bearer_token(authorization: str | None) -> str:
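    """Return the raw token from an ``Authorization`` header value.

    >>> _extract_bearer_token("Bearer abc123")
    'abc123'
    >>> _extract_bearer_token("abc123")
    'abc123'
    >>> _extract_bearer_token(None)
    ''
    """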
    if not authorization:
        return ""
    value = authorization.strip()
    if value.lower().startswith("bearer "):
        return value[7:].strip()
    return value


async def get_lobster_user(
    request: Request,
    authorization: Annotated[str | None, Header()] = None,
    x_user_id: Annotated[str | None, Header(alias="X-User-Id")] = None,
    x_dept_id: Annotated[str | None, Header(alias="X-Dept-Id")] = None,
    x_org_id: Annotated[str | None, Header(alias="X-Org-Id")] = None,
) -> UserContext:
    """Authenticate lobster service calls and build a zm-rag user context."""
    expected = settings.lobster_kb_api_key
    if not expected:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="lobster KB adapter is not configured",
        )

    actual = _extract_bearer_token(authorization)
    if actual != expected:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="invalid lobster KB service token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_id = _ensure_prefix(x_user_id, "U_")
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-User-Id header is required",
        )

    return UserContext(
        user_id=user_id,
        office_id=_map_lobster_org_to_office_id(x_org_id),
        dept_id=_ensure_prefix(x_dept_id, "D_"),
        area_id=_map_lobster_area_id(request),
        role_ids=_map_lobster_role_ids(request),
    )


async def _build_research_engine(request: Request) -> ResearchEngine:
    es_client: ESClient = request.app.state.es_client
    neo4j_client: Neo4jClient = request.app.state.neo4j_client
    redis_client: RedisClient = request.app.state.redis_client
    mysql_client: MySQLClient | None = getattr(request.app.state, "mysql_client", None)
    embedding_client: EmbeddingClient = request.app.state.embedding_client
    llm_client: LLMClient = request.app.state.llm_client

    embedding_svc = EmbeddingService(embedding_client, redis=redis_client.raw)
    graph_svc = GraphQueryService(neo4j_client)
    session_store = build_research_session_store(
        redis_client=redis_client,
        mysql_client=mysql_client,
    )
    return ResearchEngine(
        es_client=es_client,
        embedding_service=embedding_svc,
        graph_service=graph_svc,
        llm_client=llm_client,
        session_store=session_store,
    )


def _first_passage(doc: dict[str, Any]) -> str:
    passages = doc.get("passages")
    if isinstance(passages, list):
        for passage in passages:
            text = str(passage or "").strip()
            if text:
                return text
    summary = str(doc.get("summary") or "").strip()
    return summary


def _strip_highlight_tags(value: str) -> str:
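    """Remove ``<em>`` highlight markup emitted by search highlighting.

    >>> _strip_highlight_tags('Found <em class="hl">term</em> here')
    'Found term here'
    """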
    return _HIGHLIGHT_EM_TAG_RE.sub("", value)


def _safe_source_url(value: Any) -> str:
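    """Return ``value`` only when it is an absolute http(s) URL, else an empty string.

    >>> _safe_source_url("https://example.com/doc/1")
    'https://example.com/doc/1'
    >>> _safe_source_url("javascript:alert(1)")
    ''
    >>> _safe_source_url("/relative/path")
    ''
    """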
    if value is None:
        return ""
    url = str(value).strip()
    if not url:
        return ""
    parsed = urlparse(url)
    if parsed.scheme.lower() not in {"http", "https"}:
        return ""
    if not parsed.netloc:
        return ""
    return url


def _source_url_from_doc(doc: dict[str, Any]) -> str:
    source_metadata = doc.get("source_metadata") or {}
    if not isinstance(source_metadata, dict):
        source_metadata = {}
    return _safe_source_url(
        doc.get("source_url")
        or doc.get("sourceUrl")
        or doc.get("url")
        or source_metadata.get("source_url")
        or source_metadata.get("sourceUrl")
        or source_metadata.get("url")
    )


def _hit_from_qa_doc(doc: dict[str, Any], *, escape_html: bool = True) -> dict[str, Any]:
    doc_id = str(doc.get("doc_id") or "")
    source_type = str(doc.get("source_type") or "search")
    source_label = doc.get("source_label")
    title = str(doc.get("title") or doc.get("summary") or doc_id)
    snippet = _first_passage(doc)
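    # The lobster "escapeHtml" flag is honoured by stripping <em> highlight
    # markup; the remaining text is not otherwise HTML-escaped here.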
    if escape_html:
        title = _strip_highlight_tags(title)
        snippet = _strip_highlight_tags(snippet)
    source_url = _source_url_from_doc(doc)
    return {
        "docId": doc_id,
        "title": title,
        "snippet": snippet,
        "score": doc.get("score") or 0,
        "source": source_type,
        "scopeName": source_label or source_type,
        "url": source_url,
        "sourceUrl": source_url,
        "updatedAt": doc.get("publish_date") or "",
    }


@router.get("/scopes")
async def scopes(
    user: Annotated[UserContext, Depends(get_lobster_user)],
) -> dict[str, Any]:
    """Return visible scopes.

    Scope support is intentionally deferred. Returning an empty list makes the
    existing lobster frontend use its "all KB" default.
    """
    return {"scopes": []}


@router.post("/search")
async def search(
    body: LobsterKbSearchRequest,
    user: Annotated[UserContext, Depends(get_lobster_user)],
    request: Request,
) -> dict[str, Any]:
    """Search zm-rag and return lobster KB ``hits``."""
    # Lobster can send a different orgId for the same user_id. PermissionService
    # caches by user_id only, so bypass the cache here to avoid stale office_id
    # tokens leaking across org switches.
    perm = await PermissionService(redis_client=None).resolve(user)
    engine = await _build_research_engine(request)

    # scopeIds are accepted for forward compatibility, but intentionally unused
    # in P0. zm-rag ACL remains the authoritative visibility filter.
    result = await engine.qa_search(body.query, perm, seed_doc_ids=None, llm=body.llm)
    docs = result.get("documents") or []
    hits = [
        _hit_from_qa_doc(doc, escape_html=body.escapeHtml)
        for doc in docs[: body.maxResults]
        if str(doc.get("doc_id") or "").strip()
    ]
    context_text = result.get("context_text", "")
    graph_evidence_text = result.get("graph_evidence_text", "")
    guide_evidence_text = result.get("guide_evidence_text", "")
    if body.escapeHtml:
        context_text = _strip_highlight_tags(str(context_text))
        graph_evidence_text = _strip_highlight_tags(str(graph_evidence_text))
        guide_evidence_text = _strip_highlight_tags(str(guide_evidence_text))

    logger.info(
        "lobster_kb_search",
        user_id=user.user_id,
        query_len=len(body.query),
        hit_count=len(hits),
    )
    return {
        "hits": hits,
        "contextText": context_text,
        "graphEvidenceText": graph_evidence_text,
        "guideEvidenceText": guide_evidence_text,
    }


@router.post("/detail")
async def detail(
    body: LobsterKbDetailRequest,
    user: Annotated[UserContext, Depends(get_lobster_user)],
    request: Request,
) -> dict[str, Any]:
    """Return a markdown-ish document body for lobster's citation drawer."""
    es_client: ESClient = request.app.state.es_client
    # Keep detail ACL evaluation tied to the current Lobster org context; see
    # the search endpoint comment for why this adapter bypasses permission
    # caching.
    perm = await PermissionService(redis_client=None).resolve(user)

    try:
        meta_resp = await es_client.raw.get(
            index=settings.es_meta_index,
            id=body.docId,
        )
    except NotFoundError as exc:
        raise HTTPException(
            status_code=404,
            detail=f"Document {body.docId} not found",
        ) from exc

    raw_meta = meta_resp if isinstance(meta_resp, dict) else meta_resp.body
    meta = raw_meta.get("_source", {})
    doc_acl_ids = meta.get("acl_ids") or []
    if not perm.has_acl_access(doc_acl_ids):
        raise HTTPException(status_code=403, detail="No permission to view this document")

    content_hash = str(meta.get("content_hash") or "")
    if not content_hash:
        raise HTTPException(status_code=404, detail="Document content is not available")

    total_chunks = await _count_detail_chunks(es_client, content_hash)
    if total_chunks <= 0:
        raise HTTPException(status_code=404, detail="Document chunks are not available")

    page = await _locate_detail_page(
        es_client,
        content_hash,
        offset=body.offset,
        query=body.query,
        section_hint=body.sectionHint,
        allow_locate=body.offset == 0 and body.chunkCharOffset == 0,
    )
    selected = await _fetch_detail_chunks_by_offset(
        es_client,
        content_hash,
        offset=page["startOffset"],
        limit=body.limit,
    )
    content, rendered_count, completed_count, char_truncated, next_char_offset = _render_detail_markdown_page(
        meta,
        selected,
        body.maxChars or settings.lobster_kb_detail_max_chars,
        chunk_char_offset=body.chunkCharOffset,
    )

    next_offset = page["startOffset"] + completed_count
    has_more = next_char_offset is not None or next_offset < total_chunks
    source_url = _source_url_from_doc(meta)

    return {
        "docId": body.docId,
        "content": content,
        "url": source_url,
        "sourceUrl": source_url,
        "sourceType": meta.get("source_system") or meta.get("source_type") or "",
        "offset": page["startOffset"],
        "requestedOffset": body.offset,
        "chunkCharOffset": body.chunkCharOffset,
        "limit": body.limit,
        "maxChars": body.maxChars or settings.lobster_kb_detail_max_chars,
        "returnedChunks": rendered_count,
        "completedChunks": completed_count,
        "totalChunks": total_chunks,
        "nextOffset": next_offset if has_more else None,
        "nextChunkCharOffset": next_char_offset,
        "hasMore": has_more,
        "truncated": bool(char_truncated or has_more),
        "locatedBy": page.get("locatedBy"),
        "matchedOffset": page.get("matchedOffset"),
        "matchedChunkIndex": page.get("matchedChunkIndex"),
    }


async def _count_detail_chunks(es_client: ESClient, content_hash: str) -> int:
    resp = await es_client.raw.count(
        index=settings.es_chunk_index,
        body=_build_detail_chunk_count_body(content_hash),
    )
    raw = resp if isinstance(resp, dict) else resp.body
    try:
        return int(raw.get("count") or 0)
    except (TypeError, ValueError):
        return 0


async def _fetch_detail_chunks_by_offset(
    es_client: ESClient,
    content_hash: str,
    *,
    offset: int,
    limit: int,
) -> list[dict[str, Any]]:
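    """Return up to ``limit`` chunk hits starting at ``offset``, in ``chunk_index`` order.

    The scan restarts from the first chunk on every call, pages forward with
    ``search_after``, and skips the first ``offset`` hits by position.
    """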
    safe_offset = max(0, offset)
    safe_limit = max(1, min(limit, 50))
    selected: list[dict[str, Any]] = []
    seen = 0
    search_after: list[Any] | None = None

    while len(selected) < safe_limit:
        resp = await es_client.raw.search(
            index=settings.es_chunk_index,
            body=_build_detail_chunk_search_body(
                content_hash,
                size=max(_DETAIL_SCAN_BATCH_SIZE, safe_limit),
                search_after=search_after,
            ),
        )
        raw = resp if isinstance(resp, dict) else resp.body
        hits = raw.get("hits", {}).get("hits", [])
        if not hits:
            break

        for hit in hits:
            if seen >= safe_offset and len(selected) < safe_limit:
                selected.append(hit)
            seen += 1
            if len(selected) >= safe_limit:
                break

        search_after = hits[-1].get("sort")
        if not search_after:
            break

    return selected


async def _locate_detail_page(
    es_client: ESClient,
    content_hash: str,
    *,
    offset: int,
    query: str | None,
    section_hint: str | None,
    allow_locate: bool,
) -> dict[str, Any]:
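    """Best-effort locator for the page that should open in the citation drawer.

    When ``allow_locate`` is true and ``query``/``section_hint`` terms exist,
    scan all chunks, pick the highest-scoring one, and start the page one chunk
    earlier for context; otherwise keep the requested ``offset``.
    """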
    page: dict[str, Any] = {
        "startOffset": max(0, offset),
        "locatedBy": None,
        "matchedOffset": None,
        "matchedChunkIndex": None,
    }
    terms = _detail_terms(query, section_hint)
    if not allow_locate or not terms:
        return page

    best_score = 0
    best_offset = 0
    best_hit: dict[str, Any] | None = None
    seen = 0
    search_after: list[Any] | None = None

    while True:
        resp = await es_client.raw.search(
            index=settings.es_chunk_index,
            body=_build_detail_chunk_search_body(
                content_hash,
                size=_DETAIL_SCAN_BATCH_SIZE,
                search_after=search_after,
            ),
        )
        raw = resp if isinstance(resp, dict) else resp.body
        hits = raw.get("hits", {}).get("hits", [])
        if not hits:
            break

        for hit in hits:
            score = _score_detail_chunk(hit, terms)
            if score > best_score:
                best_score = score
                best_offset = seen
                best_hit = hit
            seen += 1

        search_after = hits[-1].get("sort")
        if not search_after:
            break

    if best_hit is not None and best_score > 0:
        src = best_hit.get("_source", {})
        page.update(
            {
                "startOffset": max(0, best_offset - 1),
                "locatedBy": "sectionHint" if section_hint and section_hint.strip() else "query",
                "matchedOffset": best_offset,
                "matchedChunkIndex": src.get("chunk_index"),
            }
        )
    return page


def _build_detail_chunk_count_body(content_hash: str) -> dict[str, Any]:
    return {"query": {"term": {"content_hash": content_hash}}}


def _build_detail_chunk_search_body(
    content_hash: str,
    *,
    size: int = _DETAIL_SCAN_BATCH_SIZE,
    search_after: list[Any] | None = None,
) -> dict[str, Any]:
    """Build chunk scans used after doc-level ACL already passed.

    Chunk ACL is merged across all documents sharing the same content_hash.
    Applying it again here can reject an otherwise visible document when the
    requested doc meta is public but a shared sibling has private ACL tokens.
    """
    body: dict[str, Any] = {
        "size": max(1, min(size, _DETAIL_SCAN_BATCH_SIZE)),
        "query": {"term": {"content_hash": content_hash}},
        "sort": [{"chunk_index": {"order": "asc", "unmapped_type": "integer"}}],
        "_source": _DETAIL_CHUNK_SOURCE_FIELDS,
    }
    if search_after:
        body["search_after"] = search_after
    return body


def _chunk_sort_key(hit: dict[str, Any]) -> tuple[int, str]:
    src = hit.get("_source", {})
    idx = src.get("chunk_index")
    try:
        numeric_idx = int(idx)
    except (TypeError, ValueError):
        numeric_idx = 0
    return numeric_idx, str(hit.get("_id") or "")


def _chunk_text(hit: dict[str, Any]) -> str:
    src = hit.get("_source", {})
    heading = src.get("heading_hierarchy") or []
    heading_text = " ".join(str(x) for x in heading if x) if isinstance(heading, list) else str(heading)
    return "\n".join(
        part for part in [heading_text, str(src.get("content") or "")] if part
    )


def _detail_terms(query: str | None, section_hint: str | None) -> list[str]:
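    """Build lowercase match terms: section hint first, then query, then their tokens.

    >>> _detail_terms("budget approval flow", None)
    ['budget approval flow', 'budget', 'approval', 'flow']
    >>> _detail_terms("refund", "Refund Policy")
    ['refund policy', 'refund', 'policy']
    """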
    raw_parts = [part.strip().lower() for part in (section_hint, query) if part and part.strip()]
    terms: list[str] = []
    for raw in raw_parts:
        terms.append(raw)
        for token in re.findall(r"[\w\u4e00-\u9fff]{2,}", raw):
            terms.append(token)
    return list(dict.fromkeys(term for term in terms if term))


def _score_detail_chunk(hit: dict[str, Any], terms: list[str]) -> int:
    if not terms:
        return 0
    text = _chunk_text(hit).lower()
    score = 0
    for term in terms:
        if term in text:
            score += 10 if len(term) > 6 else 4
    return score


def _select_detail_chunks(
    chunk_hits: list[dict[str, Any]],
    *,
    offset: int,
    limit: int,
    query: str | None = None,
    section_hint: str | None = None,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
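    """Select a detail page from pre-fetched chunk hits, entirely in memory.

    Mirrors ``_locate_detail_page`` plus ``_fetch_detail_chunks_by_offset`` for
    callers that already hold every chunk hit: sort by ``chunk_index``,
    optionally relocate the page via ``query``/``section_hint``, and return the
    selected slice together with page metadata.
    """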
    ordered = sorted(chunk_hits, key=_chunk_sort_key)
    total = len(ordered)
    safe_limit = max(1, min(limit, 50))
    start = max(0, min(offset, total))
    page: dict[str, Any] = {
        "startOffset": start,
        "locatedBy": None,
        "matchedOffset": None,
        "matchedChunkIndex": None,
    }

    terms = _detail_terms(query, section_hint)
    if start == 0 and terms:
        scored = [
            (idx, _score_detail_chunk(hit, terms), hit)
            for idx, hit in enumerate(ordered)
        ]
        best = max(scored, key=lambda item: item[1], default=None)
        if best and best[1] > 0:
            best_offset, _, best_hit = best
            start = max(0, best_offset - 1)
            src = best_hit.get("_source", {})
            page.update(
                {
                    "startOffset": start,
                    "locatedBy": "sectionHint" if section_hint and section_hint.strip() else "query",
                    "matchedOffset": best_offset,
                    "matchedChunkIndex": src.get("chunk_index"),
                }
            )

    return ordered[start : start + safe_limit], page


def _render_detail_markdown_page(
    meta: dict[str, Any],
    chunk_hits: list[dict[str, Any]],
    max_chars: int,
    *,
    chunk_char_offset: int = 0,
) -> tuple[str, int, int, bool, int | None]:
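    """Render one markdown page from the selected chunks.

    Returns ``(content, rendered_chunks, completed_chunks, char_truncated,
    next_char_offset)``; ``next_char_offset`` is the character offset at which
    to resume the first chunk on the next call, or ``None`` when nothing needs
    resuming.
    """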
    rendered_chunks = 0
    completed_chunks = 0
    truncated = False
    parts = _detail_markdown_header(meta)
    max_chars = max(1, max_chars)

    for pos, hit in enumerate(chunk_hits):
        content_offset = max(0, chunk_char_offset) if pos == 0 else 0
        chunk_parts = _detail_chunk_markdown_parts(hit, content_offset=content_offset)
        if not chunk_parts:
            completed_chunks += 1
            continue
        candidate_parts = parts + chunk_parts + [""]
        candidate = "\n".join(candidate_parts).strip()
        if len(candidate) > max_chars:
            if rendered_chunks == 0:
                rendered, next_char_offset = _render_partial_first_chunk(
                    parts,
                    hit,
                    max_chars,
                    content_offset,
                )
                rendered_count = 1 if next_char_offset != content_offset else 0
                completed_count = (
                    completed_chunks + 1
                    if next_char_offset is None and rendered_count
                    else completed_chunks
                )
                return rendered, rendered_count, completed_count, True, next_char_offset
            truncated = True
            break
        parts = candidate_parts
        rendered_chunks += 1
        completed_chunks += 1

    return "\n".join(parts).strip(), rendered_chunks, completed_chunks, truncated, None


def _detail_markdown_header(meta: dict[str, Any]) -> list[str]:
    parts: list[str] = []
    title = str(meta.get("title") or "Knowledge Base Document")
    parts.append(f"# {title}\n")

    meta_lines: list[str] = []
    for label, key in (
        ("Doc Number", "doc_number"),
        ("Issuing Org", "issuing_org"),
        ("Doc Type", "doc_type"),
        ("Publish Date", "publish_date"),
        ("Knowledge Category", "knowledge_category"),
    ):
        value = meta.get(key)
        if value:
            meta_lines.append(f"- {label}: {value}")
    if meta_lines:
        parts.append("\n".join(meta_lines))
        parts.append("")

    summary = str(meta.get("summary") or "").strip()
    if summary:
        parts.append("## Summary\n")
        parts.append(summary)
        parts.append("")

    parts.append("## Content\n")
    return parts


def _detail_chunk_markdown_parts(
    hit: dict[str, Any],
    *,
    content_offset: int = 0,
) -> list[str]:
    src = hit.get("_source", {})
    text = str(src.get("content") or "").strip()
    if not text:
        return []
    if content_offset >= len(text):
        return []
    parts: list[str] = []
    heading = src.get("heading_hierarchy") or []
    if isinstance(heading, list) and heading:
        parts.append("### " + " / ".join(str(x) for x in heading if x))
    page = src.get("page_number")
    if page is not None:
        parts.append(f"> Page: {page}")
    parts.append(text[content_offset:])
    return parts


def _render_partial_first_chunk(
    header_parts: list[str],
    hit: dict[str, Any],
    max_chars: int,
    content_offset: int,
) -> tuple[str, int | None]:
    src = hit.get("_source", {})
    text = str(src.get("content") or "").strip()
    if content_offset >= len(text):
        return "\n".join(header_parts).strip(), None

    prefix_parts: list[str] = []
    heading = src.get("heading_hierarchy") or []
    if isinstance(heading, list) and heading:
        prefix_parts.append("### " + " / ".join(str(x) for x in heading if x))
    page = src.get("page_number")
    if page is not None:
        prefix_parts.append(f"> Page: {page}")

    marker = "\n\n...[content truncated]"
    base = "\n".join(header_parts + prefix_parts).strip()
    if base:
        base += "\n"
    budget = max_chars - len(base) - len(marker)
    if budget <= 0:
        rendered = (base[:max_chars] if base else "") + marker
        return rendered.strip(), content_offset

    remaining = text[content_offset:]
    slice_text = remaining[:budget]
    next_offset = content_offset + len(slice_text)
    rendered = (base + slice_text).rstrip()
    if next_offset < len(text):
        rendered += marker
        return rendered, next_offset
    return rendered, None


def _render_detail_markdown(
    meta: dict[str, Any],
    chunk_hits: list[dict[str, Any]],
    max_chars: int,
) -> str:
    content, _, _, _, _ = _render_detail_markdown_page(meta, chunk_hits, max_chars)
    return content
