# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute

import asyncio
import os
from typing import List, Optional, Tuple, Union

from fastapi import HTTPException

import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import CallTypesLiteral
from litellm.utils import get_utc_datetime

from .rate_limiter_utils import convert_priority_to_percent


class DynamicRateLimiterCache:
    """
    Thin wrapper on DualCache for this file.

    Stores, per model group and per UTC minute, the set of callers that used
    the model, so the number of active projects can be read back.
    """

    def __init__(self, cache: DualCache) -> None:
        self.cache = cache
        self.ttl = 60  # 1 min ttl

    def _minute_key(self, model: str) -> str:
        """Build the cache key for `model` in the current UTC minute ("HH-MM:model")."""
        minute_bucket = get_utc_datetime().strftime("%H-%M")
        return "{}:{}".format(minute_bucket, model)

    async def async_get_cache(self, model: str) -> Optional[int]:
        """
        Count the entries stored for `model` in the current minute.

        Parameters:
        - model: str, the name of the model group

        Returns:
        - int, the length of the cached value, or None if nothing is cached
        """
        members = await self.cache.async_get_cache(key=self._minute_key(model))
        return None if members is None else len(members)

    async def async_set_cache_sadd(self, model: str, value: List):
        """
        Add value(s) to the set kept for `model` in the current minute.

        Parameters:
        - model: str, the name of the model group
        - value: List, the members to add to the set

        Returns:
        - None

        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
        try:
            await self.cache.async_set_cache_sadd(
                key=self._minute_key(model), value=value, ttl=self.ttl
            )
        except Exception as e:
            # Log with traceback here, then let the caller see the failure too.
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}".format(
                    str(e)
                )
            )
            raise e


class _PROXY_DynamicRateLimitHandler(CustomLogger):
    """
    Proxy hook that shares a model group's remaining tpm/rpm between the
    projects currently calling it.

    - Before a call: rejects the request with HTTP 429 if the caller's share
      of tpm or rpm is 0, otherwise records the caller as active.
    - After a successful call: attaches the computed limits to the response
      as additional headers.

    `update_variables` must be called to set `self.llm_router` before the
    hooks can compute anything; until then `check_available_usage` logs the
    error and returns all `None` (no limit enforced).
    """

    def __init__(self, internal_usage_cache: DualCache):
        self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # tasks are held here until they finish to avoid being garbage-collected
        # mid-execution.
        self._background_tasks: set = set()

    def update_variables(self, llm_router: Router):
        """Set the router used to look up model group info and usage."""
        self.llm_router = llm_router

    @staticmethod
    def _remaining_capacity(
        total: Optional[int], used: Optional[int]
    ) -> Optional[int]:
        """Return `total - used`; `total` if usage is unknown; None if no limit is set."""
        if total is None:
            return None
        if used is None:
            return total
        return total - used

    @staticmethod
    def _project_share(
        remaining: Optional[int], weight: float, active_projects: Optional[int]
    ) -> Optional[int]:
        """
        Return one project's share of `remaining`, scaled by `weight`.

        The share is split evenly across `active_projects` when that is a
        positive number; None or 0 means there is nobody to split with.
        The result is truncated to int and never negative.
        """
        if remaining is None:
            return None
        share = remaining * weight
        if active_projects:  # skips None and 0 (0 would raise ZeroDivisionError)
            share = share / active_projects
        return max(int(share), 0)

    async def check_available_usage(
        self, model: str, priority: Optional[str] = None
    ) -> Tuple[
        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
    ]:
        """
        For a given model, get its available tpm/rpm

        Params:
        - model: str, the name of the model in the router model_list
        - priority: Optional[str], the priority for the request.

        Returns
        - Tuple[available_tpm, available_rpm, model_tpm, model_rpm, active_projects]
            - available_tpm: int or null - always 0 or positive.
            - available_rpm: int or null - always 0 or positive.
            - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
            - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
            - active_projects: int or null

        Any exception raised while computing is logged and swallowed; the
        method then returns (None, None, None, None, None).
        """
        try:
            # Get model info first - it is needed to convert the priority value
            model_group_info: Optional[ModelGroupInfo] = (
                self.llm_router.get_model_group_info(model_group=model)
            )

            # weight stays 1 unless a licensed priority reservation applies
            weight: float = 1
            if (
                litellm.priority_reservation is None
                or priority not in litellm.priority_reservation
            ):
                verbose_proxy_logger.error(
                    "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
                        priority, litellm.priority_reservation
                    )
                )
            elif priority is not None and litellm.priority_reservation is not None:
                if os.getenv("LITELLM_LICENSE", None) is None:
                    verbose_proxy_logger.error(
                        "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
                    )
                else:
                    value = litellm.priority_reservation[priority]
                    weight = convert_priority_to_percent(value, model_group_info)

            active_projects = await self.internal_usage_cache.async_get_cache(
                model=model
            )
            (
                current_model_tpm,
                current_model_rpm,
            ) = await self.llm_router.get_model_group_usage(model_group=model)

            total_model_tpm: Optional[int] = None
            total_model_rpm: Optional[int] = None
            if model_group_info is not None:
                total_model_tpm = model_group_info.tpm
                total_model_rpm = model_group_info.rpm

            remaining_model_tpm = self._remaining_capacity(
                total_model_tpm, current_model_tpm
            )
            remaining_model_rpm = self._remaining_capacity(
                total_model_rpm, current_model_rpm
            )

            available_tpm = self._project_share(
                remaining_model_tpm, weight, active_projects
            )
            available_rpm = self._project_share(
                remaining_model_rpm, weight, active_projects
            )

            return (
                available_tpm,
                available_rpm,
                remaining_model_tpm,
                remaining_model_rpm,
                active_projects,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
                    str(e)
                )
            )
            return None, None, None, None, None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
        """
        - For a model group
        - Check if tpm/rpm available
        - Raise HTTPException (429) if no tpm/rpm available
        - Otherwise, if any limit applies, record this key as an active project

        Requests without a "model" in `data` are passed through untouched.
        Always returns None when it does not raise.
        """
        if "model" not in data:
            return None

        key_priority: Optional[str] = user_api_key_dict.metadata.get("priority", None)
        (
            available_tpm,
            available_rpm,
            model_tpm,
            model_rpm,
            active_projects,
        ) = await self.check_available_usage(
            model=data["model"], priority=key_priority
        )
        ### CHECK TPM ###
        if available_tpm is not None and available_tpm == 0:
            raise HTTPException(
                status_code=429,
                detail={
                    "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
                        user_api_key_dict.api_key,
                        available_tpm,
                        model_tpm,
                        active_projects,
                    )
                },
            )
        ### CHECK RPM ###
        elif available_rpm is not None and available_rpm == 0:
            raise HTTPException(
                status_code=429,
                detail={
                    "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
                        user_api_key_dict.api_key,
                        available_rpm,
                        model_rpm,
                        active_projects,
                    )
                },
            )
        elif available_rpm is not None or available_tpm is not None:
            ## UPDATE CACHE WITH ACTIVE PROJECT
            # Runs in the background so the request is not delayed by the cache write.
            task = asyncio.create_task(
                self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                    model=data["model"],  # type: ignore
                    value=[user_api_key_dict.token or "default_key"],
                )
            )
            # Keep a strong reference until the task completes.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
        return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Attach rate-limit debugging headers to a successful ModelResponse.

        For a `ModelResponse`, looks up the model group of the deployment that
        served it, recomputes the available usage, stores the values under
        `response._hidden_params["additional_headers"]` and returns the response.
        Any other response type is delegated to the parent class hook.

        Never raises: any exception is logged and the response is returned as-is.
        """
        try:
            if isinstance(response, ModelResponse):
                model_id = response._hidden_params["model_id"]
                model_info = self.llm_router.get_model_info(id=model_id)
                # Explicit check instead of `assert`, which is stripped under -O.
                if model_info is None:
                    raise ValueError(
                        "Model info for model with id={} is None".format(model_id)
                    )
                key_priority: Optional[str] = user_api_key_dict.metadata.get(
                    "priority", None
                )
                (
                    available_tpm,
                    available_rpm,
                    model_tpm,
                    model_rpm,
                    active_projects,
                ) = await self.check_available_usage(
                    model=model_info["model_name"], priority=key_priority
                )
                response._hidden_params["additional_headers"] = (
                    {  # Add additional response headers - easier debugging
                        "x-litellm-model_group": model_info["model_name"],
                        "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
                        "x-ratelimit-remaining-litellm-project-requests": available_rpm,
                        "x-ratelimit-remaining-model-tokens": model_tpm,
                        "x-ratelimit-remaining-model-requests": model_rpm,
                        "x-ratelimit-current-active-projects": active_projects,
                    }
                )

                return response
            return await super().async_post_call_success_hook(
                data=data,
                user_api_key_dict=user_api_key_dict,
                response=response,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            return response
