"""
Max Iterations Limiter for LiteLLM Proxy.

Enforces a per-session cap on the number of LLM calls an agentic loop can make.
Callers send a `session_id` with each request (via `x-litellm-session-id` header
or `metadata.session_id`), and this hook counts calls per session. When the count
exceeds `max_iterations` (configured in agent litellm_params or key metadata), returns 429.

Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as parallel_request_limiter_v3.py.
"""

import os
from typing import TYPE_CHECKING, Any, Optional, Union

from fastapi import HTTPException

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    # Expose the real class to static type checkers only.
    InternalUsageCache = _InternalUsageCache
else:
    # At runtime the import above is skipped (presumably to avoid a circular
    # import with litellm.proxy.utils — TODO confirm), so the annotation on
    # __init__ degrades to Any.
    InternalUsageCache = Any


# Redis Lua script for atomic increment with TTL.
# Returns the new count after increment.
# Only sets EXPIRE on first increment (when count becomes 1), so the session
# window starts at the first call and is NOT extended by subsequent calls.
# Running as a Lua script makes INCR + EXPIRE atomic across proxy instances.
MAX_ITERATIONS_INCREMENT_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local current = redis.call('INCR', key)
if current == 1 then
    redis.call('EXPIRE', key, ttl)
end

return current
"""

# Default TTL for session iteration counters (1 hour).
# Overridable via the LITELLM_MAX_ITERATIONS_TTL environment variable.
DEFAULT_MAX_ITERATIONS_TTL: int = 3600


class _PROXY_MaxIterationsHandler(CustomLogger):
    """
    Pre-call hook that caps the number of LLM calls per session.

    Configuration:
        - max_iterations: set in agent litellm_params (preferred)
          e.g. litellm_params={"max_iterations": 25}
          Falls back to key metadata max_iterations for backwards compatibility.
        - session_id: sent by caller via x-litellm-session-id header or
          metadata.session_id in request body

    Cache key pattern:
        {session_iterations:<session_id>}:count

    Multi-instance support:
        A registered Redis Lua script performs the increment atomically
        (same pattern as parallel_request_limiter_v3). When Redis is not
        configured or fails, a best-effort in-memory counter is used instead.
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache
        # Counter TTL in seconds; env var overrides the 1-hour default.
        raw_ttl = os.getenv("LITELLM_MAX_ITERATIONS_TTL", DEFAULT_MAX_ITERATIONS_TTL)
        self.ttl = int(raw_ttl)

        # Pre-register the Lua script with Redis when available
        # (same pattern as the v3 parallel request limiter).
        redis_cache = self.internal_usage_cache.dual_cache.redis_cache
        self.increment_script = (
            None
            if redis_cache is None
            else redis_cache.async_register_script(MAX_ITERATIONS_INCREMENT_SCRIPT)
        )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Enforce the per-session iteration cap before the API call is made.

        Pulls the session_id out of the request metadata and the configured
        max_iterations from agent litellm_params (or key metadata). If either
        is absent the request passes through untouched; otherwise the session
        counter is incremented and a 429 is raised once it exceeds the cap.
        """
        # No session id -> nothing to count against.
        session_id = self._get_session_id(data)
        if session_id is None:
            return None

        # No configured cap -> limiter is effectively disabled for this key/agent.
        limit = self._get_max_iterations(user_api_key_dict)
        if limit is None:
            return None

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, max_iterations=%s",
            session_id,
            limit,
        )

        # Bump the counter first, then compare against the cap.
        counter_key = self._make_cache_key(session_id)
        new_count = await self._increment_and_get(counter_key)

        if new_count > limit:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Max iterations exceeded for session {session_id}. "
                    f"Current count: {new_count}, max_iterations: {limit}."
                ),
            )

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, count=%s/%s",
            session_id,
            new_count,
            limit,
        )

        return None

    def _get_session_id(self, data: dict) -> Optional[str]:
        """Return the session_id from request metadata, or None if absent.

        Checks `metadata` first, then `litellm_metadata` (used for /thread
        and /assistant endpoints).
        """
        for container_key in ("metadata", "litellm_metadata"):
            container = data.get(container_key) or {}
            sid = container.get("session_id")
            if sid is not None:
                return str(sid)
        return None

    def _get_max_iterations(
        self, user_api_key_dict: UserAPIKeyAuth
    ) -> Optional[int]:
        """Return the configured iteration cap, or None when unconfigured.

        Agent litellm_params take precedence; key metadata is the
        backwards-compatible fallback.
        """
        agent_id = user_api_key_dict.agent_id
        if agent_id is not None:
            # Imported lazily so module import does not pull in the registry.
            from litellm.proxy.agent_endpoints.agent_registry import (
                global_agent_registry,
            )

            agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
            if agent is not None:
                configured = (agent.litellm_params or {}).get("max_iterations")
                if configured is not None:
                    return int(configured)

        # Backwards-compatible fallback: key metadata.
        fallback = (user_api_key_dict.metadata or {}).get("max_iterations")
        return int(fallback) if fallback is not None else None

    def _make_cache_key(self, session_id: str) -> str:
        """
        Build the cache key for a session's iteration counter.

        The {session_iterations:<session_id>} hash-tag keeps every key for
        one session on the same Redis Cluster slot.
        """
        return f"{{session_iterations:{session_id}}}:count"

    async def _increment_and_get(self, cache_key: str) -> int:
        """
        Atomically bump the session counter and return its new value.

        Prefers the registered Redis Lua script (atomic across proxy
        instances); on Redis failure or absence, falls back to the
        in-memory counter.
        """
        script = self.increment_script
        if script is not None:
            try:
                raw = await script(
                    keys=[cache_key],
                    args=[self.ttl],
                )
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
                    str(e),
                )
            else:
                return int(raw)

        # Best-effort local fallback.
        return await self._in_memory_increment(cache_key)

    async def _in_memory_increment(self, cache_key: str) -> int:
        """Increment the counter in the local in-memory cache (with TTL)."""
        existing = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        updated = 1 if existing is None else int(existing) + 1
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=updated,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return updated
