"""Auto-fallback fetcher chain: httpx (Tier 2) → playwright (Tier 3).

Per design doc §4.1: default Tier 2; on 412/403/empty/challenge → Tier 3.
Tier 4 (DrissionPage) not yet wired — add when Tier 3 also fails in production.

Two operational hardenings layered on top:
  1. Host cooldown — if a host's connection-reset count crosses a threshold,
     short-circuit both Tier 2 and Tier 3 with an immediate error for the
     next COOLDOWN_SECONDS. This stopped the 1831 reset-storms we saw on
     gd_gkmlpt — repeatedly hammering a WAF after it's already cold-shouldered
     us just deepens the lockout.
  2. Network-layer errors don't trigger Tier 3 — TCP RST / DNS fail / connect
     timeout are blockers a headless browser can't bypass, retrying just
     wastes 30s of playwright nav timeout per article.
"""
from __future__ import annotations

import logging
import threading
import time
from collections import defaultdict
from urllib.parse import urlparse

from govcrawler.cookies import get_default_store
from govcrawler.fetcher.browser import FetchResult, fetch_html as fetch_html_browser
from govcrawler.fetcher.http_client import fetch_html_http

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Minimum HTML length (in characters) for a response to count as usable.
# Anything shorter is assumed to be a JS-rendered shell and triggers the
# browser fallback.
MIN_HTML_CHARS = 500

# Prefixes that mark a FetchResult.error string as a network-layer failure
# (TCP/DNS level) rather than an application-level rejection. A headless
# browser can't fix these; falling back to Tier 3 just reproduces the same
# RST/timeout 30s later. Each entry is an httpx exception type name followed
# by ":", intended to match the `error=f"{type(e).__name__}: ..."` string
# the http_client builds.
# NOTE(review): http_client is not visible from this module — confirm its
# error format still starts with the exception class name.
_NETWORK_ERROR_PREFIXES = (
    "ReadError:",          # TCP RST mid-read (the gd_gkmlpt failure mode)
    "RemoteProtocolError:",  # server hung up unexpectedly
    "ConnectError:",       # can't establish TCP
    "ConnectTimeout:",     # SYN timeout
    "ReadTimeout:",        # read timed out (server stalled, not a challenge)
)

# ---------------------------------------------------------------------------
# Per-host failure tracker + cooldown
# ---------------------------------------------------------------------------
# A host trips into cooldown after _FAIL_THRESHOLD consecutive network-layer
# failures (any error matching _NETWORK_ERROR_PREFIXES, not only connection
# resets) and is then skipped for _COOLDOWN_SECONDS (600s = 10 min). Any
# outcome that is not a network-layer error resets the counter.
# NOTE(review): an earlier note described these as "tunable via env", but
# this module defines them as plain constants with no environment lookup —
# confirm whether an override exists elsewhere before relying on one.
_FAIL_THRESHOLD = 3
_COOLDOWN_SECONDS = 600

# Held whenever the two dicts below are read or modified.
_host_lock = threading.Lock()
# host -> number of consecutive network-layer failures seen so far.
_host_fails: dict[str, int] = defaultdict(int)
# host -> time.time() timestamp at which that host's cooldown ends.
_host_cooldown_until: dict[str, float] = {}


def _is_network_error(fr: FetchResult) -> bool:
    """Return True when ``fr.error`` begins with a network-layer error prefix.

    A result without an error message is never a network error.
    """
    message = fr.error
    # str.startswith accepts a tuple and matches if any one prefix matches.
    return bool(message) and message.startswith(_NETWORK_ERROR_PREFIXES)


def _record_host_outcome(host: str, fr: FetchResult) -> None:
    """Update the per-host failure tracker with the outcome of one fetch.

    A network-layer error (as classified by ``_is_network_error``) increments
    the host's consecutive-failure counter and, once the counter reaches
    ``_FAIL_THRESHOLD``, (re)arms a ``_COOLDOWN_SECONDS`` cooldown. Any other
    outcome — a clean fetch *or* an application-level failure such as an HTTP
    status or challenge page — clears the counter and any cooldown, because
    the host did answer at the network layer.

    Args:
        host: Key identifying the host in the tracker dicts.
        fr: Result of the fetch attempt that just finished.
    """
    # _is_network_error already returns False for an empty error, so no
    # separate `fr.error` guard is needed. It is pure, so run it unlocked.
    network_failure = _is_network_error(fr)
    tripped_count = 0  # failure count if this call armed the cooldown
    recovered = False

    with _host_lock:
        if network_failure:
            _host_fails[host] += 1
            if _host_fails[host] >= _FAIL_THRESHOLD:
                _host_cooldown_until[host] = time.time() + _COOLDOWN_SECONDS
                tripped_count = _host_fails[host]
        else:
            # pop() instead of assigning 0: the defaultdict reads back as 0
            # either way, but this avoids keeping a permanent zero entry for
            # every healthy host ever fetched.
            recovered = _host_fails.pop(host, 0) > 0
            _host_cooldown_until.pop(host, None)

    # Log after releasing the lock so a slow log handler cannot stall other
    # threads waiting to check or record host state.
    if tripped_count:
        log.warning(
            "host cooldown: %s — %d consecutive network errors, "
            "skipping for %ds",
            host, tripped_count, _COOLDOWN_SECONDS,
        )
    elif recovered:
        log.info("host %s recovered — clearing fail counter", host)


def _check_host_cooldown(host: str) -> FetchResult | None:
    """Fast-fail check for a host that is currently cooling down.

    Returns None when `host` has no active cooldown (an expired cooldown is
    cleared, together with the host's failure counter, as a side effect).
    Otherwise returns a synthetic FetchResult with strategy "cooldown" that
    the caller should use as the final result instead of running Tier 2 or
    Tier 3. Its url fields are left empty for the caller to fill in.
    """
    with _host_lock:
        deadline = _host_cooldown_until.get(host)
        if deadline is None:
            return None
        if deadline <= time.time():
            # Cooldown has run out — forget it so the next fetch goes through.
            _host_cooldown_until.pop(host, None)
            _host_fails[host] = 0
            return None
        seconds_left = int(deadline - time.time())

    return FetchResult(
        url="",
        final_url="",
        status=0,
        html="",
        fetched_at=time.time(),
        duration_ms=0,
        is_challenge=False,
        error=f"host_cooldown: {host} blocked for {seconds_left}s more (WAF protection)",
        strategy="cooldown",
    )


def _should_fallback(fr: FetchResult) -> bool:
    """Decide whether a fetch result warrants retrying with the browser tier.

    An errored result falls back only when the error is application-level;
    network-layer errors won't be helped by switching to a browser. A result
    without an error falls back on a blocking/5xx status, a challenge page,
    or a body that is missing or shorter than MIN_HTML_CHARS.
    """
    if fr.error:
        return not _is_network_error(fr)
    # One short-circuit chain, evaluated in the same order as the checks
    # are described above.
    return bool(
        fr.status in (403, 412, 429)
        or fr.status >= 500
        or fr.is_challenge
        or not fr.html
        or len(fr.html) < MIN_HTML_CHARS
    )


def fetch_html(url: str, *, force_browser: bool = False) -> FetchResult:
    """Try Tier 2 (httpx) first; fall back to Tier 3 (playwright) on failure.

    Args:
        url: Page to fetch. Its lower-cased netloc is the key used for the
            per-host cooldown tracker and for cookie invalidation.
        force_browser: Skip Tier 2 entirely and go straight to the browser
            (e.g. known-challenge domains).

    Returns:
        The FetchResult of the last tier that ran, or a synthetic cooldown
        result (with `url`/`final_url` filled in) when the host is in active
        cooldown and no fetch was attempted. When both tiers report an error,
        the returned error is ``"httpx:<tier2> | playwright:<tier3>"``.
    """
    host = urlparse(url).netloc.lower()

    # Hot-path: if this host is in cooldown, fail fast — don't spend
    # ~20s on httpx timeout + ~30s on playwright nav timeout.
    cooled = _check_host_cooldown(host)
    if cooled is not None:
        cooled.url = url
        cooled.final_url = url
        return cooled

    if force_browser:
        browser_result = fetch_html_browser(url)
        _record_host_outcome(host, browser_result)
        return browser_result

    http_result = fetch_html_http(url)
    if not _should_fallback(http_result):
        _record_host_outcome(host, http_result)
        return http_result

    # FETCH-04: Tier 2 failed with injected cookies → they're stale/invalid.
    # Drop the pool entry so playwright re-primes from a clean slate.
    try:
        get_default_store().invalidate(host)
    except Exception as exc:
        # Best-effort: a broken cookie store must not abort the fetch, but
        # the failure should be visible rather than silently swallowed.
        log.warning("cookie invalidation failed for %s: %r", host, exc)

    browser_result = fetch_html_browser(url)
    # Record the outcome BEFORE rewriting the error string: classification is
    # prefix-based, and the combined message below always starts with
    # "httpx:", which would hide a network-layer error from the tracker.
    _record_host_outcome(host, browser_result)
    # Preserve the original Tier 2 failure info if the browser also failed.
    if browser_result.error and http_result.error:
        browser_result.error = (
            f"httpx:{http_result.error} | playwright:{browser_result.error}"
        )
    return browser_result
