"""Fixed-window rate-limiting middleware backed by Redis.

Strategy
--------
For each (path_prefix, identifier) pair, we maintain a Redis key holding a
fixed-window counter (requests are bucketed by time window).  The identifier
is the JWT ``sub`` claim when available, otherwise the client IP address.

Limits (requests / window_seconds):
    POST /api/ai/v1/research   →  10 / 60   (LLM-heavy endpoint)
    GET  /api/ai/v1/search     →  60 / 60
    POST /api/ai/v1/ingest     → 200 / 60   (machine-to-machine webhook)
    *  (default)               → 200 / 60

Configuration can be overridden via environment variables:
    RATE_LIMIT_RESEARCH   = "10/60"
    RATE_LIMIT_SEARCH     = "60/60"
    RATE_LIMIT_DEFAULT    = "200/60"
    RATE_LIMIT_ENABLED    = "true"

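Redis-backed fixed-window request rate-limiting middleware.
Requests are counted per (path prefix, user identifier); once the threshold is
exceeded the middleware responds with 429.
Fail-open policy: when Redis is unavailable requests are let through, so that
a cache outage never blocks business traffic.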
"""

from __future__ import annotations

import time
from typing import Callable

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from app.config import settings
from app.utils.logger import get_logger

logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Rate limit table
#
# Every _LIMITS entry is (path_prefix, max_requests, window_seconds).
# _DEFAULT_LIMIT is a (max_requests, window_seconds) pair.
# ---------------------------------------------------------------------------

_LIMITS: list[tuple[str, int, int]] = [
    (
        "/api/ai/v1/research",
        settings.rate_limit_research_max,
        settings.rate_limit_research_window,
    ),
    (
        "/api/ai/v1/search",
        settings.rate_limit_search_max,
        settings.rate_limit_search_window,
    ),
    # The ingest limit is a literal here; it is not read from settings.
    (
        "/api/ai/v1/ingest",
        200,
        60,
    ),
]
_DEFAULT_LIMIT: tuple[int, int] = (
    settings.rate_limit_default_max,
    settings.rate_limit_default_window,
)


def _get_limit(path: str) -> tuple[int, int]:
    """Return the ``(max_requests, window_seconds)`` limit that applies to *path*.

    ``_LIMITS`` is scanned in order and the first entry whose prefix *path*
    starts with wins; ``_DEFAULT_LIMIT`` is returned when no entry matches.
    """
    candidates = (
        (max_requests, window_seconds)
        for route_prefix, max_requests, window_seconds in _LIMITS
        if path.startswith(route_prefix)
    )
    return next(candidates, _DEFAULT_LIMIT)


def _extract_identifier(request: Request) -> str:
    """Return the per-client identifier used to key the rate-limit counter.

    Resolution order:

    1. ``user:<sub>`` - the ``sub`` claim of a ``Bearer`` JWT found in the
       ``Authorization`` header (scheme matched case-insensitively).  A
       missing, ``null`` or empty ``sub`` is ignored so that such tokens do
       not all share one counter.
    2. ``ip:<addr>`` - the first non-empty entry of ``X-Forwarded-For``.
    3. ``ip:<host>`` - ``request.client.host``, or ``ip:unknown`` when the
       request carries no client information.

    This function never raises on malformed credentials; it falls through to
    the IP-based identifier instead.

    NOTE(security): the JWT payload is decoded WITHOUT verifying its
    signature, so the ``sub`` value is client-controlled; a caller can pick a
    fresh ``sub`` per request to obtain a fresh counter.  The same applies to
    ``X-Forwarded-For`` unless it is set/overwritten by a trusted proxy.
    Verifying the token requires key material that is not available here.
    """
    import base64
    import json

    auth_header = request.headers.get("Authorization", "")
    # HTTP authentication scheme names are case-insensitive.
    if auth_header[:7].lower() == "bearer ":
        parts = auth_header[7:].split(".")
        if len(parts) == 3:
            segment = parts[1]
            try:
                # JWT segments are base64url without padding; restore exactly
                # as many '=' as the length requires.
                padded = segment + "=" * (-len(segment) % 4)
                payload = json.loads(base64.urlsafe_b64decode(padded))
            except (ValueError, RecursionError):
                # ValueError covers bad base64 (binascii.Error), bad UTF-8 and
                # bad JSON; RecursionError covers hostile deeply nested JSON.
                payload = None
            # The payload may decode to any JSON value; only objects have claims.
            if isinstance(payload, dict):
                sub = payload.get("sub")
                if sub is not None and sub != "":
                    return f"user:{sub}"

    # Fall back to client IP
    forwarded_for = request.headers.get("X-Forwarded-For", "")
    first_hop = forwarded_for.split(",")[0].strip()
    if first_hop:
        return f"ip:{first_hop}"
    client = request.client
    return f"ip:{client.host if client else 'unknown'}"


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Fixed-window rate limiting backed by Redis ``INCR`` + ``EXPIRE``.

    Only paths under ``/api/ai/`` are limited; every other path (which
    includes ``/health``) is passed straight through.  The limiter is
    fail-open: if anything goes wrong while talking to Redis or building the
    counter key, the request is forwarded without being counted.

    ``INCR`` is atomic, but the ``EXPIRE`` issued for a fresh counter is a
    separate command.  If the process dies between the two, that key is left
    without a TTL; because the key embeds the window bucket it is never
    incremented again, so the consequence is a stale key, not a wrong count.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Count the request against its window and forward or reject it.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer.

        Returns:
            The downstream response with ``X-RateLimit-Limit`` and
            ``X-RateLimit-Remaining`` headers added, or a 429 JSON response
            carrying ``Retry-After`` when the limit is exceeded.  Exceptions
            raised by the downstream application propagate unchanged.
        """
        if not settings.rate_limit_enabled:
            return await call_next(request)

        # Only rate-limit the AI API paths
        path = request.url.path
        if not path.startswith("/api/ai/"):
            return await call_next(request)

        # Only the rate-limit bookkeeping is guarded.  ``call_next`` must stay
        # outside this ``try``: otherwise an error raised by the downstream
        # application would be reported as a Redis failure and the request
        # would be dispatched a second time by the fail-open handler.
        try:
            redis = request.app.state.redis_client.raw
            max_req, window = _get_limit(path)
            identifier = _extract_identifier(request)

            # Fixed-window key: bucket the clock by the window length.
            now = int(time.time())
            bucket = now // window
            # Seconds until the current bucket rolls over (1..window).
            retry_after = window - now % window
            key = f"rl:{path.split('/')[4] if path.count('/') >= 4 else 'api'}:{identifier}:{bucket}"

            current = await redis.incr(key)
            if current == 1:
                await redis.expire(key, window * 2)  # 2x window to handle clock skew
        except Exception as exc:
            # Fail open — don't block requests if Redis is down
            logger.warning("rate_limit_redis_error", error=str(exc))
            return await call_next(request)

        if current > max_req:
            logger.warning(
                "rate_limit_exceeded",
                path=path,
                identifier=identifier,
                count=current,
                limit=max_req,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "detail": f"请求频率超出限制，请 {retry_after} 秒后重试。",
                    "limit": max_req,
                    "window_seconds": window,
                },
                headers={"Retry-After": str(retry_after)},
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(max_req)
        response.headers["X-RateLimit-Remaining"] = str(max(0, max_req - current))
        return response
