# This file runs a health check for the LLM, used on litellm/proxy

import asyncio
import logging
import random
import sys
import threading
import time
from typing import Any, Awaitable, Callable, List, Optional

import litellm
from litellm.constants import DEFAULT_HEALTH_CHECK_PROMPT, HEALTH_CHECK_TIMEOUT_SECONDS

logger = logging.getLogger(__name__)

# Keys removed from endpoint data before it is shown in health check output
# (prompt payloads and credentials). Applied by `_clean_endpoint_data` whenever
# `details` is not False.
# NOTE(review): every key NOT listed here is displayed as-is; other
# credential-like params (e.g. session tokens, client secrets) would appear in
# the output unless added — confirm this list is complete.
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key",
    "prompt",
    "input",
    "vertex_credentials",
    "aws_access_key_id",
    "aws_secret_access_key",
]

# The only keys kept when `details` is False (allow-list used by
# `_clean_endpoint_data`).
MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]


def _get_process_rss_mb() -> Optional[float]:
    """
    Return this process's *peak* resident set size in MB, or None.

    The value comes from ``resource.getrusage(RUSAGE_SELF).ru_maxrss``, which is
    the high-water mark of RSS, not the current RSS. Its unit is
    platform-dependent: bytes on macOS ("darwin"), kilobytes on Linux.
    Returns None when the ``resource`` module is unavailable (it is Unix-only)
    or the call fails for any reason.
    """
    try:
        import resource

        peak_rss = float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
    except Exception:
        return None
    bytes_per_unit_to_mb = (1024 * 1024) if sys.platform == "darwin" else 1024
    return peak_rss / bytes_per_unit_to_mb


def _rss_mb_for_log() -> str:
    """Format `_get_process_rss_mb()` for log lines: two decimals, or "unknown"."""
    reading = _get_process_rss_mb()
    return "unknown" if reading is None else f"{reading:.2f}"


def _get_random_llm_message():
    """
    Build a one-message user conversation for a health check probe.

    The content is picked at random from a fixed pair of short prompts.
    """
    candidate_prompts = ["Hey how's it going?", "What's 1 + 1?"]
    picked_prompt = random.choice(candidate_prompts)
    user_turn = {"role": "user", "content": picked_prompt}
    return [user_turn]


def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
    """
    Clean the endpoint data for display to users.

    Returns a NEW dict; `endpoint_data` itself is never modified. Callers may
    pass a live deployment config, so popping keys from it in place would leak
    this display-only cleanup into the deployment.

    Args:
        endpoint_data: litellm params (optionally merged with a health check
            result) to display.
        details: when exactly ``False``, keep only MINIMAL_DISPLAY_PARAMS.
            Otherwise, drop ILLEGAL_DISPLAY_PARAMS and the internal
            ``litellm_logging_obj`` and keep everything else.

    Returns:
        The filtered dict.
    """
    if details is False:
        return {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
    # The logging object is internal plumbing, never something to display.
    hidden_keys = {"litellm_logging_obj", *ILLEGAL_DISPLAY_PARAMS}
    return {k: v for k, v in endpoint_data.items() if k not in hidden_keys}


def filter_deployments_by_id(
    model_list: List,
) -> List:
    """
    Remove duplicate deployments, keyed on ``model_info["id"]``.

    The first deployment seen for each id is kept and input order is
    preserved. Deployments whose ``model_info`` is missing/None, or whose id
    is missing or falsy, are dropped entirely.
    """
    first_by_id: dict = {}
    for deployment in model_list:
        deployment_id = (deployment.get("model_info") or {}).get("id")
        if not deployment_id:
            continue
        # setdefault keeps the earliest deployment for a repeated id.
        first_by_id.setdefault(deployment_id, deployment)
    return list(first_by_id.values())


async def run_with_timeout(task, timeout):
    """
    Await `task` for at most `timeout` seconds.

    Returns the awaited value, or ``{"error": "Timeout exceeded"}`` when the
    timeout elapses.
    """
    try:
        outcome = await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        # wait_for() has already cancelled `task` itself on timeout; sibling
        # health check tasks are deliberately left alone.
        return {"error": "Timeout exceeded"}
    else:
        return outcome


async def _run_model_health_check(model: dict):
    """
    Run `litellm.ahealth_check` for one deployment, bounded by a timeout.

    `model["litellm_params"]` is updated IN PLACE by
    `_update_litellm_params_for_health_check` (messages, max_tokens, model,
    ...), so anything that later reads the deployment's params sees the
    health-check versions.

    The timeout is `model_info["health_check_timeout"]` when set (truthy),
    else HEALTH_CHECK_TIMEOUT_SECONDS. Returns the health check result, or
    `{"error": "Timeout exceeded"}` via `run_with_timeout`.
    """
    deployment_params = model["litellm_params"]
    deployment_info = model.get("model_info", {})
    check_mode = deployment_info.get("mode", None)
    deployment_params = _update_litellm_params_for_health_check(
        deployment_info, deployment_params
    )
    time_budget = (
        deployment_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS
    )

    health_check = litellm.ahealth_check(
        deployment_params,
        mode=check_mode,
        prompt=DEFAULT_HEALTH_CHECK_PROMPT,
        input=["test from litellm"],
    )
    return await run_with_timeout(health_check, time_budget)


async def _run_health_checks_with_bounded_concurrency(
    models: list,
    concurrency_limit: int,
    check_fn: Optional[Callable[[dict], Awaitable[Any]]] = None,
) -> tuple[list, int]:
    """
    Run health checks with at most `concurrency_limit` tasks in flight.

    Args:
        models: deployments to check. Results are returned in the same order.
        concurrency_limit: maximum number of checks running at once. Values
            below 1 are treated as 1, so every model is still checked.
        check_fn: coroutine function called once per model. Defaults to
            `_run_model_health_check`.

    Returns:
        (results, peak_in_flight). Each result is the check's return value,
        the exception it raised, or an `asyncio.CancelledError` instance if
        the check task was cancelled. This mirrors
        `asyncio.gather(..., return_exceptions=True)` on the unbounded path.
        `peak_in_flight` is the largest number of checks that were running
        at the same time.

    If this coroutine is itself cancelled (or exits early for any other
    reason), checks still in flight are cancelled instead of being left
    running unobserved. `asyncio.wait` does not do this on its own.
    """
    run_check = check_fn if check_fn is not None else _run_model_health_check
    limit = max(1, concurrency_limit)
    results: list = [None] * len(models)
    tasks_to_index: dict[asyncio.Task, int] = {}
    model_iter = iter(enumerate(models))
    peak_in_flight = 0

    def _schedule_next() -> None:
        # Start the next pending model's check, if any remain.
        nonlocal peak_in_flight
        next_item = next(model_iter, None)
        if next_item is None:
            return
        idx, next_model = next_item
        task = asyncio.create_task(run_check(next_model))
        tasks_to_index[task] = idx
        peak_in_flight = max(peak_in_flight, len(tasks_to_index))

    for _ in range(min(limit, len(models))):
        _schedule_next()

    try:
        while tasks_to_index:
            done, _ = await asyncio.wait(
                set(tasks_to_index),
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                idx = tasks_to_index.pop(task)
                if task.cancelled():
                    # task.result() would raise CancelledError (a BaseException)
                    # out of this loop; record it like gather() does instead.
                    results[idx] = asyncio.CancelledError()
                else:
                    exc = task.exception()
                    results[idx] = exc if exc is not None else task.result()
                _schedule_next()
    finally:
        # Only non-empty when we are leaving abnormally (e.g. cancellation).
        for task in tasks_to_index:
            task.cancel()

    return results, peak_in_flight


async def _perform_health_check(
    model_list: list,
    details: Optional[bool] = True,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check for each model in the list.

    Args:
        model_list: deployments to check. Each is a dict with `litellm_params`
            and optionally `model_info`.
        details: forwarded to `_clean_endpoint_data`. `False` keeps only the
            minimal display fields.
        max_concurrency: optional limit on concurrent health check requests.
            A positive int enables bounded dispatch; anything else runs all
            checks at once.
        instrumentation_context: optional dict. When its `enabled` key is
            truthy, a dispatch summary is logged at debug level, tagged with
            its `source` and `cycle_id` values.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): lists of cleaned endpoint
        dicts, each in `model_list` order.
    """

    instrumentation_context = instrumentation_context or {}
    instrumentation_enabled = bool(instrumentation_context.get("enabled", False))
    cycle_id = instrumentation_context.get("cycle_id", "unknown")
    source = instrumentation_context.get("source", "unknown")

    dispatch_mode = "unbounded"
    peak_in_flight = 0
    if isinstance(max_concurrency, int) and max_concurrency > 0:
        dispatch_mode = "bounded"
        results, peak_in_flight = await _run_health_checks_with_bounded_concurrency(
            model_list, max_concurrency
        )
    else:
        tasks = [
            asyncio.create_task(_run_model_health_check(model)) for model in model_list
        ]
        peak_in_flight = len(tasks)
        results = await asyncio.gather(*tasks, return_exceptions=True)

    if instrumentation_enabled:
        logger.debug(
            "health_check_dispatch_summary source=%s cycle_id=%s mode=%s model_count=%d max_concurrency=%s peak_in_flight=%d thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            dispatch_mode,
            len(model_list),
            max_concurrency,
            peak_in_flight,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    return _partition_health_check_results(results, model_list, details)


def _partition_health_check_results(
    results: list, model_list: list, details: Optional[bool]
) -> tuple[list, list]:
    """
    Split raw per-model results into (healthy, unhealthy) display dicts.

    A result is healthy only if it is a dict without an "error" key. Any other
    dict is unhealthy. A non-dict result (e.g. an exception captured by
    gather/bounded dispatch) is also unhealthy and is reported with the
    deployment's params only.
    """
    healthy_endpoints: list = []
    unhealthy_endpoints: list = []

    for result, model in zip(results, model_list):
        litellm_params = model["litellm_params"]
        if isinstance(result, dict):
            # Result keys (e.g. "error") override the deployment's params.
            endpoint_data = {**litellm_params, **result}
            is_healthy = "error" not in result
        else:
            # Copy so cleaning for display can never modify the live config.
            endpoint_data = {**litellm_params}
            is_healthy = False

        target = healthy_endpoints if is_healthy else unhealthy_endpoints
        target.append(_clean_endpoint_data(endpoint_data, details))

    return healthy_endpoints, unhealthy_endpoints


def _update_litellm_params_for_health_check(
    model_info: dict, litellm_params: dict
) -> dict:
    """
    Update the litellm params for health check.

    NOTE: `litellm_params` is modified IN PLACE and also returned.

    - gets a short `messages` param for health check
    - sets `max_tokens` to `health_check_max_tokens` when configured. Otherwise
      it is set to 1, unless the (health check) model name contains a "*"
      wildcard.
    - updates the `model` param with the `health_check_model` if it exists Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
    - updates the `voice` param with the `health_check_voice` for `audio_speech` mode if it exists Doc: https://docs.litellm.ai/docs/proxy/health#text-to-speech-models
    - for Bedrock models with region routing (bedrock/region/model), strips the litellm routing prefix but preserves the model ID
    """
    litellm_params["messages"] = _get_random_llm_message()

    _health_check_max_tokens = model_info.get("health_check_max_tokens", None)
    if _health_check_max_tokens is not None:
        litellm_params["max_tokens"] = _health_check_max_tokens
    elif "*" not in (
        model_info.get("health_check_model") or litellm_params.get("model") or ""
    ):
        # Presumably to keep the probe cheap: a single output token suffices.
        litellm_params["max_tokens"] = 1

    _health_check_model = model_info.get("health_check_model", None)
    if _health_check_model is not None:
        litellm_params["model"] = _health_check_model
    if model_info.get("mode", None) == "audio_speech":
        litellm_params["voice"] = model_info.get("health_check_voice", "alloy")

    # Health checks bypass get_llm_provider() for the model param, so litellm's
    # Bedrock routing syntax must be removed here (see helper docstring).
    model = litellm_params.get("model")
    if isinstance(model, str) and model.startswith("bedrock/"):
        litellm_params["model"] = _strip_bedrock_region_routing(model)

    return litellm_params


def _strip_bedrock_region_routing(model: str) -> str:
    """
    Turn a litellm "bedrock/..." route into the model id to send to AWS.

    Issue #15807: without this, health checks send "region/model" as the model
    ID to AWS, which produces URLs like
    "bedrock-runtime.../model/us-west-2/mistral.../invoke".

    - strips only the leading "bedrock/" prefix
    - drops every path segment found in `BedrockModelInfo.all_global_regions`
    - keeps everything else. That covers route prefixes (converse/, invoke/)
      and handlers (llama/, deepseek_r1/, ...). It also covers cross-region
      inference profile prefixes such as "us." / "eu.", which AWS requires in
      inference profile IDs.

    Examples:
        "bedrock/us-west-2/model"          -> "model"
        "bedrock/converse/us-west-2/model" -> "converse/model"
        "bedrock/llama/arn:..."            -> "llama/arn:..."
    """
    from litellm.llms.bedrock.common_utils import BedrockModelInfo

    regions = BedrockModelInfo.all_global_regions
    route = model.removeprefix("bedrock/")
    return "/".join(part for part in route.split("/") if part not in regions)


async def perform_health_check(
    model_list: list,
    model: Optional[str] = None,
    cli_model: Optional[str] = None,
    details: Optional[bool] = True,
    model_id: Optional[str] = None,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check on the system.

    When model_id is provided, only the deployment with that id is checked
    (so models that share the same name but have different ids are checked separately).
    If no deployment has that id, the id filter is not applied and all
    deployments are checked.
    When model (name) is provided, all deployments matching that name are checked.
    Matching is on `litellm_params.model` first; if nothing matches, on
    `model_name`.

    Args:
        model_list: configured deployments. When empty, `cli_model` (if set)
            is checked as a single ad-hoc deployment.
        model: optional model name filter (see above).
        cli_model: model passed on the command line, used only when
            `model_list` is empty.
        details: `False` returns only minimal fields per endpoint.
        model_id: optional deployment id filter (see above).
        max_concurrency: optional limit on concurrent health check requests.
        instrumentation_context: optional dict. When its `enabled` key is
            truthy, cycle start/complete/failure lines are logged.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): two lists of endpoint dicts
        cleaned for display. Both are empty when there is nothing to check.
    """
    instrumentation_context = instrumentation_context or {}
    instrumentation_enabled = bool(instrumentation_context.get("enabled", False))
    cycle_id = instrumentation_context.get("cycle_id", "unknown")
    source = instrumentation_context.get("source", "unknown")

    if not model_list:
        if cli_model:
            # filter_deployments_by_id() drops deployments without a
            # model_info id, so the ad-hoc CLI deployment needs one. Without
            # it, the deployment would be filtered out and never checked.
            model_list = [
                {
                    "model_name": cli_model,
                    "litellm_params": {"model": cli_model},
                    "model_info": {"id": cli_model},
                }
            ]
        else:
            if instrumentation_enabled:
                logger.debug(
                    "health_check_cycle_skipped source=%s cycle_id=%s reason=no_models",
                    source,
                    cycle_id,
                )
            return [], []

    cycle_start_time = time.monotonic()
    requested_model_count = len(model_list)

    # Filter by model_id first so a single deployment is checked when id is specified
    if model_id is not None:
        _by_id = [
            x for x in model_list if (x.get("model_info") or {}).get("id") == model_id
        ]
        if _by_id:
            model_list = _by_id
    elif model is not None:
        _new_model_list = [
            x for x in model_list if x["litellm_params"]["model"] == model
        ]
        if not _new_model_list:
            _new_model_list = [x for x in model_list if x["model_name"] == model]
        model_list = _new_model_list

    post_filter_model_count = len(model_list)
    model_list = filter_deployments_by_id(
        model_list=model_list
    )  # filter duplicate deployments (e.g. when model alias'es are used)
    deduped_model_count = len(model_list)

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_start source=%s cycle_id=%s requested_model_count=%d post_model_filter_count=%d deduped_model_count=%d max_concurrency=%s thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            requested_model_count,
            post_filter_model_count,
            deduped_model_count,
            max_concurrency,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    try:
        healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
            model_list,
            details,
            max_concurrency=max_concurrency,
            instrumentation_context=instrumentation_context,
        )
    except Exception:
        if instrumentation_enabled:
            logger.exception(
                "health_check_cycle_failed source=%s cycle_id=%s model_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
                source,
                cycle_id,
                deduped_model_count,
                (time.monotonic() - cycle_start_time) * 1000,
                threading.active_count(),
                _rss_mb_for_log(),
            )
        raise

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_complete source=%s cycle_id=%s model_count=%d healthy_count=%d unhealthy_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            deduped_model_count,
            len(healthy_endpoints),
            len(unhealthy_endpoints),
            (time.monotonic() - cycle_start_time) * 1000,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    return healthy_endpoints, unhealthy_endpoints
