"""FastAPI middleware that records every state-changing admin call into
the `admin_audit_log` table.

Scope:
  • POST / PUT / DELETE / PATCH on /admin/api/*    (admin actions)
  • POST on /api/articles/.../ack                  (RAG export ack)

Only POST / PUT / DELETE / PATCH are logged. GET (and any other non-mutating
method such as HEAD or OPTIONS) is NOT — reads don't change state and would
dominate volume.

The middleware is BEST-EFFORT: a DB write failure here must not break the
underlying request. We catch the exception, log it (with traceback), and let
the response through.

Body bytes are read once, hashed (first 16 hex characters of the SHA-256
digest), and re-injected into the ASGI receive channel so downstream handlers
can still parse the body.

Action / resource_type / resource_id are derived heuristically from
method + path. The classifier is intentionally simple and covers the
endpoints we currently expose; new endpoints just fall through to the
generic "admin.<method>" action.
"""
from __future__ import annotations

import base64
import hashlib
import logging
import re
import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

log = logging.getLogger(__name__)


# ---- path → (action, resource_type, resource_id) classifier --------------

# Rules are tried top to bottom and the FIRST match wins. Several patterns
# below are unanchored prefixes (no trailing ``$``), so a more specific path
# must be listed before any rule that is a prefix of it — e.g.
# ``bulk-create-yaml`` has to precede ``bulk-create`` or it is misclassified.
_PATH_RULES: list[tuple[re.Pattern[str], str, str | None]] = [
    # admin: sites
    (re.compile(r"^/admin/api/sites$"), "site.create", "site"),
    (re.compile(r"^/admin/api/sites/(?P<id>[^/]+)$"), "site.update", "site"),
    (re.compile(r"^/admin/api/sites/(?P<id>[^/]+)/toggle$"), "site.toggle", "site"),
    # admin: targets
    (re.compile(r"^/admin/api/targets$"), "target.create", "target"),
    (re.compile(r"^/admin/api/targets/bulk-create-yaml"), "target.bulk_create_yaml", "target"),
    (re.compile(r"^/admin/api/targets/bulk-create"), "target.bulk_create", "target"),
    (re.compile(r"^/admin/api/targets/discover-html"), "target.discover_html", "target"),
    (re.compile(r"^/admin/api/targets/bulk-run"), "target.bulk_run", "target"),
    (re.compile(r"^/admin/api/targets/(?P<id>[^/]+)/run$"), "target.run", "target"),
    (re.compile(r"^/admin/api/targets/(?P<id>[^/]+)/toggle$"), "target.toggle", "target"),
    (re.compile(r"^/admin/api/targets/(?P<id>[^/]+)/parser$"), "target.update_parser", "target"),
    (re.compile(r"^/admin/api/targets/(?P<id>[^/]+)$"), "target.update", "target"),
    # admin: articles
    (re.compile(r"^/admin/api/articles/bulk-delete$"), "article.bulk_delete", "article"),
    (re.compile(r"^/admin/api/articles/(?P<id>\d+)$"), "article.delete", "article"),
    # admin: departments
    (re.compile(r"^/admin/api/local-departments$"), "dept.create", "department"),
    (re.compile(r"^/admin/api/local-departments/(?P<id>\d+)$"), "dept.update", "department"),
    # admin: jobs
    (re.compile(r"^/admin/api/jobs/(?P<id>[^/]+)/cancel$"), "job.cancel", "job"),
    # admin: logs retry
    (re.compile(r"^/admin/api/logs/(?P<id>\d+)/retry$"), "log.retry", "log"),
    # public: rag ack
    (re.compile(r"^/api/articles/(?P<id>\d+)/ack$"), "article.ack", "article"),
]


def _classify(method: str, path: str) -> tuple[str, str | None, str | None]:
    """Return ``(action, resource_type, resource_id)`` for a logged request.

    The first rule in ``_PATH_RULES`` whose pattern matches ``path`` decides
    the action and resource type. The resource id is that pattern's ``id``
    named group, or None when the pattern has no such group.

    The ``<x>.update`` rules match any verb on a single-resource path, so for
    a DELETE request the action is reported as ``<x>.delete`` instead. That
    keeps a deletion from being recorded as an update.

    Unmatched paths fall back to ``("admin.<method>", None, None)`` so the
    audit row always has a non-null action.
    """
    update_suffix = ".update"
    for pat, action, rtype in _PATH_RULES:
        m = pat.match(path)
        if m is None:
            continue
        if method.upper() == "DELETE" and action.endswith(update_suffix):
            action = action[: -len(update_suffix)] + ".delete"
        # Not every pattern defines an ``id`` group, hence .get().
        return action, rtype, m.groupdict().get("id")
    # Generic fallback so audit row still has a non-null action class.
    return f"admin.{method.lower()}", None, None


def _client_ip(request: Request) -> str | None:
    """Return the client IP to record in the audit row.

    Prefers the first entry of ``X-Forwarded-For`` (we run behind nginx in
    compose-prod). Falls back to the socket peer when the header is absent
    or its first entry is blank. Returns None if neither is available.

    NOTE(review): the first X-Forwarded-For entry is whatever the original
    client sent, unless the proxy overwrites the header rather than
    appending to it, so it can be forged. Confirm the nginx config before
    treating this value as trustworthy for forensics.
    """
    fwd = request.headers.get("x-forwarded-for")
    if fwd:
        # First entry is the original client; rest are proxies.
        first = fwd.split(",", 1)[0].strip()
        # A blank entry (e.g. ", 10.0.0.1") would otherwise be recorded as
        # an empty/NULL IP; use the socket peer instead.
        if first:
            return first
    return request.client.host if request.client else None


def _actor_from_basic_auth(request: Request) -> str | None:
    """Extract the username from an HTTP Basic ``Authorization`` header.

    The scheme check is case-insensitive. Returns None when the header is
    missing, uses another scheme, fails base64 decoding, or carries an empty
    username. Undecodable UTF-8 bytes are replaced rather than rejected.

    Never raises: the middleware calls this outside its own error guard, so
    every decode failure is turned into None here.
    """
    header = request.headers.get("authorization", "")
    if header.lower().startswith("basic "):
        encoded = header.partition(" ")[2]
        try:
            decoded = base64.b64decode(encoded).decode("utf-8", "replace")
        except Exception:
            return None
        username = decoded.partition(":")[0]
        return username if username else None
    return None


def _should_log(method: str, path: str) -> bool:
    """Decide whether a request belongs in the audit log.

    Only mutating verbs (POST / PUT / DELETE / PATCH, case-insensitive) are
    eligible. Of those, everything under ``/admin/api/`` is logged, plus any
    path of the form ``/api/articles/<...>/ack`` (the public RAG ack write).
    """
    is_write = method.upper() in {"POST", "PUT", "DELETE", "PATCH"}
    if not is_write:
        return False
    is_admin = path.startswith("/admin/api/")
    is_rag_ack = path.startswith("/api/articles/") and path.endswith("/ack")
    return is_admin or is_rag_ack


class AuditMiddleware(BaseHTTPMiddleware):
    """Write one ``AdminAuditLog`` row per state-changing admin / RAG-ack call.

    Requests rejected by ``_should_log`` pass straight through untouched.
    For the rest, the body is hashed and the request is forwarded. An audit
    row is then written in a ``finally`` block, so it is recorded even when
    the downstream handler raises. Audit write failures are logged and
    swallowed so they never alter the response.
    """

    async def dispatch(self, request: Request, call_next):
        method = request.method
        path = request.url.path
        if not _should_log(method, path):
            return await call_next(request)

        # Read body once + hash, re-inject so downstream parsers still see it.
        body = await request.body()
        digest = hashlib.sha256(body).hexdigest()[:16] if body else ""

        async def _receive():
            return {"type": "http.request", "body": body, "more_body": False}

        # Replace the receive callable so downstream handlers re-read body.
        # NOTE(review): recent Starlette releases already cache the body for
        # BaseHTTPMiddleware and may not expect receive() to keep returning
        # http.request messages; verify against the pinned Starlette version.
        request._receive = _receive  # type: ignore[attr-defined]

        actor = _actor_from_basic_auth(request)
        actor_ip = _client_ip(request)
        action, rtype, rid = _classify(method, path)

        # Monotonic clock: durations are unaffected by wall-clock adjustments.
        t0 = time.perf_counter()
        response: Response | None = None
        try:
            response = await call_next(request)
            return response
        finally:
            # status 0 marks "no response produced" (call_next raised).
            self._write_audit_row(
                actor=actor,
                actor_ip=actor_ip,
                method=method,
                path=path,
                status_code=response.status_code if response is not None else 0,
                duration_ms=int((time.perf_counter() - t0) * 1000),
                digest=digest,
                action=action,
                rtype=rtype,
                rid=rid,
            )

    @staticmethod
    def _write_audit_row(
        *,
        actor: str | None,
        actor_ip: str | None,
        method: str,
        path: str,
        status_code: int,
        duration_ms: int,
        digest: str,
        action: str,
        rtype: str | None,
        rid: str | None,
    ) -> None:
        """Insert one audit row; log and swallow any failure (best-effort).

        String fields are truncated to fixed lengths. These are presumably
        the column widths of ``AdminAuditLog``; confirm against the model.

        NOTE(review): this is a synchronous DB round-trip executed on the
        event-loop thread; consider offloading it if audit latency shows up.
        """
        try:
            # Lazy import to keep cold-start path light.
            from datetime import datetime, timezone

            from govcrawler.db import get_sessionmaker
            from govcrawler.models import AdminAuditLog

            session_factory = get_sessionmaker()
            with session_factory() as s:
                s.add(AdminAuditLog(
                    # Naive UTC, the same value datetime.utcnow() produced
                    # (utcnow() is deprecated since Python 3.12).
                    created_at=datetime.now(timezone.utc).replace(tzinfo=None),
                    actor=actor[:64] if actor else None,
                    actor_ip=actor_ip[:64] if actor_ip else None,
                    method=method[:8],
                    path=path[:500],
                    status_code=status_code,
                    duration_ms=duration_ms,
                    payload_digest=digest,
                    action=action[:64] if action else None,
                    resource_type=rtype,
                    resource_id=str(rid)[:128] if rid is not None else None,
                ))
                s.commit()
        except Exception:
            # Never let an audit failure break the underlying request.
            log.exception(
                "audit write failed actor=%s method=%s path=%s",
                actor, method, path,
            )
