from __future__ import annotations
import random
import time
from urllib.parse import urlparse

# Base gap, in seconds, enforced between two requests to the same host when
# the caller does not configure one (or passes None).
DEFAULT_INTERVAL_S = 10.0
# Positive-only jitter: actual interval is always >= configured base,
# never shorter. Random uniform(0, jitter_pct) is added on top so a
# "60 second" interval becomes "60..78 seconds" with the default 0.30.
# WAFs frequently pattern-match on perfectly periodic request streams
# ("every 60.0s exactly = bot"); spreading the gap out helps mimic a
# human browse cadence. Keep ratio modest by default; ops can override
# per-site if a target keeps tripping.
JITTER_PCT = 0.30


class HostThrottle:
    """Per-host minimum interval gate. In-process only (Phase 2 will use Valkey).

    Remembers, per host, when the previous ``wait()`` call for that host
    finished and makes the next call sleep until at least ``interval_s``
    seconds (plus a non-negative random jitter) have elapsed since then.

    Attributes:
        interval_s: Base gap in seconds between two requests to one host.
        jitter_s: Optional absolute jitter window in seconds. When set, a
            value drawn from ``uniform(0, jitter_s)`` is added to the base
            and ``jitter_pct`` is ignored.
        jitter_pct: Relative jitter window. Used only when ``jitter_s`` is
            None; the base is multiplied by ``1 + uniform(0, jitter_pct)``.
    """

    def __init__(
        self,
        interval_s: float | None = DEFAULT_INTERVAL_S,
        jitter_s: float | None = None,
        jitter_pct: float | None = JITTER_PCT,
    ):
        """Configure the throttle.

        Args:
            interval_s: Base gap in seconds. ``None`` falls back to
                ``DEFAULT_INTERVAL_S``.
            jitter_s: Absolute jitter window in seconds, clamped to >= 0.
                ``None`` selects the percent-based jitter instead.
            jitter_pct: Relative jitter window, clamped to >= 0. ``None``
                falls back to ``JITTER_PCT``.

        Raises:
            TypeError, ValueError: If a non-None argument cannot be
                converted with ``float()``.
        """
        # Accept None explicitly — callers routinely pass `crawl_target.interval_sec`
        # which is nullable. Falling back to DEFAULT_INTERVAL_S avoids silent
        # TypeError("unsupported operand type(s) for *: 'NoneType' and 'float'")
        # on the second wait() call (first call skips math because last is None,
        # hiding the bug until an actual rate-limit hits).
        #
        # All three knobs are normalised to float here for the same reason:
        # a Decimal (or other non-float numeric) would otherwise pass
        # construction and the first wait(), then blow up in the interval
        # arithmetic on the second call.
        self.interval_s = (
            DEFAULT_INTERVAL_S if interval_s is None else float(interval_s)
        )
        self.jitter_s = None if jitter_s is None else max(0.0, float(jitter_s))
        # Clamp to >= 0: random.uniform(0.0, negative) yields negative values,
        # which would shorten the interval below the configured base and break
        # the positive-only jitter guarantee.
        self.jitter_pct = (
            JITTER_PCT if jitter_pct is None else max(0.0, float(jitter_pct))
        )
        # host key -> clock reading taken at the end of the last wait() for it.
        self._last_by_host: dict[str, float] = {}

    def _host(self, url: str) -> str:
        """Return the throttle key for *url*: its lowercased ``netloc``.

        The key is the raw network-location component, so an explicit port
        or userinfo in the URL is part of the key.
        NOTE(review): a URL without a ``//`` authority part parses to an
        empty netloc, so all such URLs share one bucket — confirm callers
        always pass absolute URLs.
        """
        return urlparse(url).netloc.lower()

    def _jittered_interval(self) -> float:
        """Return the base interval plus a freshly drawn, non-negative jitter."""
        # If ops configured an absolute jitter window, use it directly;
        # otherwise retain the historic percent-based default.
        if self.jitter_s is not None:
            return self.interval_s + random.uniform(0.0, self.jitter_s)
        return self.interval_s * (1.0 + random.uniform(0.0, self.jitter_pct))

    def wait(self, url: str, *, sleep=time.sleep, now=time.monotonic) -> float:
        """Block until the host of *url* may be contacted again.

        The first call for a given host never sleeps. Later calls sleep for
        whatever remains of the jittered interval measured from the end of
        the previous call for that host.

        Args:
            url: URL about to be requested; only its host key is used.
            sleep: Callable taking a duration in seconds; injectable for tests.
            now: Callable returning the current clock reading in seconds;
                injectable for tests.

        Returns:
            The number of seconds slept (0.0 when no wait was needed).
        """
        host = self._host(url)
        now_t = now()
        last = self._last_by_host.get(host)
        sleep_s = 0.0
        if last is not None:
            # Bias positive: never shorter than the configured base.
            due = last + self._jittered_interval()
            sleep_s = max(0.0, due - now_t)
            if sleep_s > 0:
                sleep(sleep_s)
        # Re-read the clock so the stored timestamp reflects the time after
        # any sleep, not the time this call started.
        self._last_by_host[host] = now()
        return sleep_s
