"""Shared FastAPI dependencies (DI) for request handlers.

公共依赖注入模块。
定义 JWT 认证、用户上下文解析，以及各基础设施客户端（ES、Neo4j、Redis、
Embedding、LLM）的依赖获取函数，供所有路由处理器通过 FastAPI Depends 使用。
"""

from __future__ import annotations

from typing import Annotated, TYPE_CHECKING

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from pydantic import BaseModel, ValidationError

from app.config import settings

if TYPE_CHECKING:
    from app.core.permission import PermissionContext
    from app.infrastructure.embedding_client import EmbeddingClient
    from app.infrastructure.es_client import ESClient
    from app.infrastructure.llm_client import LLMClient
    from app.infrastructure.mysql_client import MySQLClient
    from app.infrastructure.neo4j_client import Neo4jClient
    from app.infrastructure.redis_client import RedisClient

# Re-usable HTTP Bearer auth schemes shared by the dependencies in this module.
# Lenient scheme: with ``auto_error=False`` a request without Bearer
# credentials resolves to ``None`` instead of an error, so handlers can allow
# anonymous access.
_optional_bearer_scheme = HTTPBearer(auto_error=False)
# Strict scheme: with ``auto_error=True`` FastAPI rejects requests lacking
# Bearer credentials before the dependent handler runs.
_bearer_scheme = HTTPBearer(auto_error=True)


# ── User context model ───────────────────────────────────────────────────────


class UserContext(BaseModel):
    """User identity extracted from a JWT bearer token.

    Expected JWT payload::

        {
            "sub": "<user_id>",
            "office_id": "O_17",
            "dept_id": "D_05",
            "area_id": "A_01",
            "role_ids": ["R_03"]
        }

    The identity extracted from the JWT is carried through the whole request
    lifecycle and is used for permission filtering and operation auditing.
    """

    # Taken from the token's ``sub`` claim; the only required field.
    user_id: str
    # Organisational scope claims; default to the empty string when absent.
    office_id: str = ""
    dept_id: str = ""
    area_id: str = ""
    # Role identifiers. pydantic copies field defaults per instance, so this
    # empty list is not shared between model instances.
    role_ids: list[str] = []


# ── Infrastructure getters ───────────────────────────────────────────────────
# These resolve to the singleton clients attached to ``app.state`` during the
# lifespan startup phase.


def get_es_client(request: Request) -> "ESClient":
    """Return the shared Elasticsearch client stored on ``app.state``.

    Meant to be used as ``Depends(get_es_client)``. Nothing is created here;
    the instance is read from ``request.app.state.es_client`` on every call.
    """
    app_state = request.app.state
    return app_state.es_client


def get_neo4j_client(request: Request) -> "Neo4jClient":
    """Return the shared Neo4j client stored on ``app.state``.

    Reads ``request.app.state.neo4j_client``; no fallback value is applied.
    """
    application = request.app
    return application.state.neo4j_client


def get_redis_client(request: Request) -> "RedisClient":
    """Return the shared Redis client stored on ``app.state``.

    Reads ``request.app.state.redis_client``; no fallback value is applied.
    """
    client = request.app.state.redis_client
    return client


def get_mysql_client(request: Request) -> "MySQLClient | None":
    """Return the shared MySQL client, or ``None`` when none is configured.

    MySQL is optional: if ``app.state`` has no ``mysql_client`` attribute,
    ``None`` is returned instead of raising.
    """
    app_state = request.app.state
    try:
        return app_state.mysql_client
    except AttributeError:
        # Attribute never attached -> MySQL is not configured for this app.
        return None


def get_embedding_client(request: Request) -> "EmbeddingClient":
    """Return the shared embedding client stored on ``app.state``.

    Reads ``request.app.state.embedding_client``; no fallback value is applied.
    """
    state = request.app.state
    embedding = state.embedding_client
    return embedding


def get_llm_client(request: Request) -> "LLMClient":
    """Return the shared LLM client stored on ``app.state``.

    Reads ``request.app.state.llm_client``; no fallback value is applied.
    """
    application = request.app
    llm = application.state.llm_client
    return llm


# ── JWT / auth dependency ────────────────────────────────────────────────────


def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 error raised for any bearer-token failure."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(_bearer_scheme)],
) -> UserContext:
    """Decode and validate the bearer JWT, returning a :class:`UserContext`.

    The token is verified with ``settings.jwt_secret`` and only the algorithm
    configured in ``settings.jwt_algorithm`` is accepted.

    The optional claims ``office_id``, ``dept_id``, ``area_id`` and
    ``role_ids`` fall back to empty values when missing *or* explicitly null.

    Raises:
        HTTPException: 401 Unauthorized (with ``WWW-Authenticate: Bearer``)
            when the token cannot be decoded/verified, lacks a ``sub`` claim,
            or carries claims whose types cannot populate ``UserContext``.
    """
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.jwt_secret,
            algorithms=[settings.jwt_algorithm],
        )
    except JWTError as exc:
        raise _unauthorized(f"Invalid or expired token: {exc}") from exc

    user_id: str | None = payload.get("sub")
    if not user_id:
        raise _unauthorized("Token missing 'sub' claim")

    try:
        return UserContext(
            user_id=user_id,
            # ``or`` also maps explicit JSON nulls to the empty default;
            # ``dict.get(key, default)`` alone only covers missing keys and
            # would pass ``None`` through to model validation.
            office_id=payload.get("office_id") or "",
            dept_id=payload.get("dept_id") or "",
            area_id=payload.get("area_id") or "",
            role_ids=payload.get("role_ids") or [],
        )
    except ValidationError as exc:
        # Wrong-typed claims are a bad token, not a server fault: report 401
        # instead of letting the validation error surface as a 500.
        raise _unauthorized("Token contains malformed claims") from exc


async def get_optional_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(_optional_bearer_scheme)],
) -> UserContext | None:
    """Return the caller's :class:`UserContext`, or ``None`` when anonymous.

    With no Bearer credentials the request proceeds anonymously. A token that
    *is* supplied must still be valid: it is decoded by
    :func:`get_current_user`, so optional and strict routes share exactly the
    same decoding and claim rules instead of two diverging copies.

    Raises:
        HTTPException: 401 Unauthorized when a supplied token is invalid.
    """
    if credentials is None:
        return None
    # An invalid token is an error (401), not a silent downgrade to anonymous.
    return await get_current_user(credentials)


async def get_permission_context(
    request: Request,
    user: Annotated[UserContext, Depends(get_current_user)],
) -> "PermissionContext":
    """Resolve the authenticated user's permission context for ACL-aware handlers.

    Requires a valid bearer token (via :func:`get_current_user`), builds a
    ``PermissionService`` backed by the shared Redis client and returns the
    result of ``PermissionService.resolve(user)``.
    """
    # Deferred import, as in the original; presumably avoids an import cycle
    # with app.core (it is only imported under TYPE_CHECKING at module level).
    from app.core.permission import PermissionService

    # Go through the module's own getter rather than reaching into app.state.
    service = PermissionService(redis_client=get_redis_client(request))
    return await service.resolve(user)


async def get_optional_permission_context(
    request: Request,
    user: Annotated[UserContext | None, Depends(get_optional_current_user)],
) -> "PermissionContext | None":
    """Resolve the permission context when a JWT is present, else return ``None``.

    Anonymous requests (``user is None``) get no permission context;
    authenticated ones are resolved exactly as by
    :func:`get_permission_context`.
    """
    if user is None:
        return None
    # Reuse the strict resolver so both paths build the service identically.
    return await get_permission_context(request, user)
