"""SSRF + path safety guard for fetch URLs (COMP-03).

Two layers:

1. Path blacklist — refuses obvious admin/login/internal endpoints even
   when a yaml entry mistakenly points at one.

2. SSRF defense — restricts scheme to http/https; rejects URLs that
   resolve to private/loopback/link-local/multicast/reserved IP ranges
   (incl. the link-local 169.254.169.254 metadata service used by AWS and
   others; note that Alibaba Cloud (阿里云) serves its metadata from
   100.100.100.200 instead, inside the 100.64.0.0/10 shared address range).
   Without this, a compromised yaml row or admin session could trick the
   api/scheduler into fetching internal services on the docker network or
   the host network. (Reviewer P1.)

Optional allowlist — when GOVCRAWLER_HOST_ALLOWLIST env is set to a
comma-separated list of suffixes (e.g. "gov.cn,news.cn,com.cn"), URLs
whose host doesn't end with any allowed suffix are rejected. Empty / unset
means "no allowlist", relying on the IP-range guard alone. The allowlist
is the recommended hardening for production.
"""
from __future__ import annotations

import ipaddress
import os
import socket
from functools import lru_cache
from urllib.parse import unquote, urlparse

# URL-path fragments that mark admin / auth / internal endpoints. _path_ok()
# does a case-insensitive substring test on the percent-decoded path, so
# "/admin" also catches "/administrator" or "/admin.php"; entries ending in
# "/" require that slash (e.g. "/user/" does not match a bare "/user").
FORBIDDEN_PATH_PATTERNS: tuple[str, ...] = (
    "/admin",
    "/api/internal",
    "/internal/",
    "/backstage/",
    "/manage/",
    "/login",
    "/logout",
    "/user/",
    "/account/",
)

# Schemes accepted by is_public_path(); any other scheme (file:, ftp:,
# gopher:, ...) is rejected.
ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})

# Hostnames refused by name, before any DNS lookup: the "localhost" alias and
# cloud metadata-service names. If such a name does resolve, its addresses are
# still subject to the IP-range checks in is_safe_to_fetch().
ALWAYS_DENY_HOSTS: frozenset[str] = frozenset({
    "localhost",
    "metadata",
    "metadata.google.internal",
    # 169.254.169.254 itself is covered by ipaddress's is_link_local check
})


def _path_ok(url: str) -> bool:
    """Return True when the URL path contains no FORBIDDEN_PATH_PATTERNS entry.

    The path is percent-decoded before matching: web servers generally decode
    the path before routing, so "/%61dmin" reaches the same handler as
    "/admin" and must be refused too. Matching is case-insensitive.
    """
    path = unquote(urlparse(url).path).lower()
    return not any(frag in path for frag in FORBIDDEN_PATH_PATTERNS)


@lru_cache(maxsize=2048)
def _resolve_host_cached(host: str) -> tuple[str, ...]:
    """Resolve *host* to a tuple of distinct IP strings, raising on failure.

    Failures propagate as exceptions on purpose: lru_cache does not memoize
    a call that raised, so only successful lookups are cached.
    """
    infos = socket.getaddrinfo(host, None)
    return tuple({info[4][0] for info in infos})


def _resolve_host(host: str) -> tuple[str, ...]:
    """DNS-resolve host → tuple of ip strings, or () if resolution fails.

    Successful answers are LRU cached because the same host gets checked many
    times per crawl. Failures are NOT cached, so a transient DNS error does
    not deny the host for the rest of the process lifetime; the trade-off is
    that a persistently failing host is looked up again on every check.
    The caller treats () as deny.
    """
    try:
        return _resolve_host_cached(host)
    except Exception:
        # Fail closed on any resolution error (socket.gaierror, IDNA
        # UnicodeError for malformed labels, ...) rather than crash the caller.
        return ()


# Expose the cache controls under the original helper name, in case tests or
# callers reset or inspect the cache via _resolve_host.cache_clear()/cache_info().
_resolve_host.cache_clear = _resolve_host_cached.cache_clear  # type: ignore[attr-defined]
_resolve_host.cache_info = _resolve_host_cached.cache_info  # type: ignore[attr-defined]


def _ip_safe(ip_str: str) -> bool:
    """Return True only if *ip_str* is a globally routable IP address.

    Returns False for strings that are not valid IPv4/IPv6 addresses and for
    loopback, private, link-local (incl. 169.254.169.254), multicast,
    unspecified and reserved addresses. It also returns False for anything
    ``ipaddress`` does not consider global. That last check catches the
    100.64.0.0/10 shared address space (RFC 6598), which ``is_private`` does
    not include; Alibaba Cloud's metadata service (100.100.100.200) lives there.
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False
    if ip.is_loopback or ip.is_private or ip.is_link_local:
        return False
    if ip.is_multicast or ip.is_unspecified or ip.is_reserved:
        return False
    # is_global is False for 100.64.0.0/10 even though is_private is too.
    return ip.is_global


def _host_allowed_by_allowlist(host: str) -> bool:
    """Check *host* against the optional GOVCRAWLER_HOST_ALLOWLIST env var.

    The env var is a comma-separated list of domain suffixes and is re-read
    on every call. When it is unset or blank, every host is allowed.
    Otherwise a host passes if it equals an entry or is a subdomain of one:
    "gov.cn" admits "gov.cn" and "www.gov.cn" but not "evilgov.cn".

    Entries and the host are normalized the same way: lowercased, with
    leading/trailing dots removed. So "gov.cn.", ".gov.cn" and "gov.cn" are
    equivalent. If the variable is set but yields no usable entry, nothing
    matches, i.e. every host is denied (fail closed).
    """
    raw = os.environ.get("GOVCRAWLER_HOST_ALLOWLIST", "").strip()
    if not raw:
        return True  # allowlist unset → no extra restriction
    normalized = (s.strip().lower().strip(".") for s in raw.split(","))
    suffixes = tuple(entry for entry in normalized if entry)
    h = host.lower().rstrip(".")
    return any(h == suf or h.endswith("." + suf) for suf in suffixes)


def is_public_path(url: str) -> bool:
    """Path-level + scheme guard (COMP-03).

    Returns False if any of:
      • the URL cannot be parsed (urlparse raises ValueError, e.g. for an
        unbalanced "[" in an IPv6 host) — refused rather than raised
      • scheme is not http/https
      • path matches a forbidden fragment (admin, login, internal …)

    Does NOT do DNS / IP-range checks — those live in is_safe_to_fetch()
    so unit tests using fake hosts ("https://x/...") still exercise this
    function as a pure path predicate.
    """
    try:
        scheme = urlparse(url).scheme
    except ValueError:
        return False
    if scheme.lower() not in ALLOWED_SCHEMES:
        return False
    return _path_ok(url)


def is_safe_to_fetch(url: str) -> bool:
    """Full SSRF guard. Adds DNS + IP-range + allowlist checks on top of
    is_public_path. Used by the live pipeline before any outbound fetch.

    A URL is considered safe when ALL of the following hold:
      • the URL parses (malformed URLs are refused instead of raising)
      • is_public_path(url) is True
      • host is not in the always-deny list (localhost, cloud metadata),
        ignoring a trailing root dot, and is not a "*.localhost" name
      • when GOVCRAWLER_HOST_ALLOWLIST is set, host matches an entry
      • an IP-literal host is public; otherwise all DNS-resolved IPs are
        public (not loopback/private/link-local/…)

    DNS resolution failures fail closed — better to refuse than let httpx
    swallow a confusing error 20s later.

    NOTE(review): this is a check-then-use guard. It only returns a bool, so
    whatever client performs the fetch resolves the host again on its own.
    A DNS-rebinding host, or a resolved answer that has since changed, can
    still reach an internal address. Pinning the vetted IP in the HTTP
    transport would close that gap.
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        return False  # e.g. "http://[::1/" — unbalanced IPv6 brackets
    if not is_public_path(url):
        return False

    host = (parsed.hostname or "").lower()
    # "localhost." and "localhost" name the same host; compare without the
    # trailing root dot so the name deny-list cannot be sidestepped with it.
    bare_host = host.rstrip(".")
    if not bare_host:
        return False
    # RFC 6761 reserves every "*.localhost" name for loopback.
    if bare_host in ALWAYS_DENY_HOSTS or bare_host.endswith(".localhost"):
        return False
    if not _host_allowed_by_allowlist(host):
        return False

    try:
        ipaddress.ip_address(host)
    except ValueError:
        pass  # not an IP literal — vet its DNS answers below
    else:
        return _ip_safe(host)

    ips = _resolve_host(host)
    if not ips:
        return False
    return all(_ip_safe(ip) for ip in ips)
