import os
import sys
import traceback
from unittest import mock

from dotenv import load_dotenv

import litellm.proxy
import litellm.proxy.proxy_server

load_dotenv()
import io
import json
import os

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,  # Set the desired logging level
    format="%(asctime)s - %(levelname)s - %(message)s",
)

from unittest.mock import AsyncMock, patch

from fastapi import FastAPI

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import (  # Replace with the actual module where your FastAPI router is defined
    app,
    initialize,
    save_worker_config,
)
from litellm.proxy.utils import ProxyLogging

# Your bearer token
token = "sk-1234"

headers = {"Authorization": f"Bearer {token}"}

example_completion_result = {
    "choices": [
        {
            "message": {
                "content": "Whispers of the wind carry dreams to me.",
                "role": "assistant",
            }
        }
    ],
}
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ],
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
example_image_generation_result = {
    "created": 1589478378,
    "data": [{"url": "https://..."}, {"url": "https://..."}],
}


def mock_patch_acompletion():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.acompletion",
        return_value=example_completion_result,
    )


def mock_patch_aembedding():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.aembedding",
        return_value=example_embedding_result,
    )


def mock_patch_aimage_generation():
    return mock.patch(
        "litellm.proxy.proxy_server.llm_router.aimage_generation",
        return_value=example_image_generation_result,
    )


@pytest.fixture(scope="function")
def fake_env_vars(monkeypatch):
    # Set some fake environment variables
    monkeypatch.setenv("OPENAI_API_KEY", "fake_openai_api_key")
    monkeypatch.setenv("OPENAI_API_BASE", "http://fake-openai-api-base")
    monkeypatch.setenv("AZURE_API_BASE", "http://fake-azure-api-base")
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake_azure_openai_api_key")
    monkeypatch.setenv("AZURE_SWEDEN_API_BASE", "http://fake-azure-sweden-api-base")
    monkeypatch.setenv("REDIS_HOST", "localhost")


@pytest.fixture(scope="function")
def client_no_auth(fake_env_vars):
    # Assuming litellm.proxy.proxy_server is an object
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
    asyncio.run(initialize(config=config_fp, debug=True))
    return TestClient(app)


@mock_patch_acompletion()
def test_chat_completion(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


def test_chat_completion_malformed_messages_returns_400(client_no_auth):
    """
    Test that malformed messages (strings instead of dicts) return 400 instead of 500.
    
    This test verifies that when a client sends messages as raw strings instead of
    {role, content} objects, LiteLLM returns a 400 invalid_request_error instead
    of a 500 Internal Server Error.
    """
    global headers
    try:
        # Test data with malformed messages (string instead of dict)
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": ["hi how are you"],  # Invalid: should be [{"role": "user", "content": "hi how are you"}]
        }

        print("testing proxy server with malformed messages")
        response = client_no_auth.post("/v1/chat/completions", json=test_data, headers=headers)
        
        print(f"response status: {response.status_code}")
        print(f"response text: {response.text}")
        
        # Should return 400, not 500
        assert response.status_code == 400, f"Expected 400, got {response.status_code}. Response: {response.text}"
        
        # Verify error format
        result = response.json()
        assert "error" in result, "Response should contain 'error' key"
        error = result["error"]
        
        # Verify error type and message
        assert error.get("type") == "invalid_request_error" or error.get("type") is None, \
            f"Expected invalid_request_error or None, got {error.get('type')}"
        assert error.get("code") == "400" or error.get("code") == 400, \
            f"Expected code 400, got {error.get('code')}"
        
        # Error message should indicate invalid request format
        error_message = error.get("message", "")
        assert len(error_message) > 0, "Error message should not be empty"
        
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


def test_get_settings_request_timeout(client_no_auth):
    """
    When no timeout is set, it should use the litellm.request_timeout value
    """
    # Set a known value for litellm.request_timeout
    import litellm

    # Make a GET request to /settings
    response = client_no_auth.get("/settings")

    # Check if the request was successful
    assert response.status_code == 200

    # Parse the JSON response
    settings = response.json()
    print("settings", settings)

    assert settings["litellm.request_timeout"] == litellm.request_timeout


@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
def test_add_headers_to_request(litellm_key_header_name):
    from fastapi import Request
    from starlette.datastructures import URL
    import json
    from litellm.proxy.litellm_pre_call_utils import (
        clean_headers,
        LiteLLMProxyRequestSetup,
    )

    headers = {
        "Authorization": "Bearer 1234",
        "X-Custom-Header": "Custom-Value",
        "X-Stainless-Header": "Stainless-Value",
        "anthropic-beta": "beta-value",
    }
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    request._body = json.dumps({"model": "gpt-3.5-turbo"}).encode("utf-8")
    request_headers = clean_headers(headers, litellm_key_header_name)
    forwarded_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(
        request_headers
    )
    assert forwarded_headers == {
        "X-Custom-Header": "Custom-Value",
        "anthropic-beta": "beta-value",
    }


@pytest.mark.parametrize(
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
@pytest.mark.parametrize(
    "forward_headers",
    [True, False],
)
@mock_patch_acompletion()
def test_chat_completion_forward_headers(
    mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
):
    global headers
    try:
        if forward_headers:
            gs = getattr(litellm.proxy.proxy_server, "general_settings")
            gs["forward_client_headers_to_llm_api"] = True
            setattr(litellm.proxy.proxy_server, "general_settings", gs)
        if litellm_key_header_name is not None:
            gs = getattr(litellm.proxy.proxy_server, "general_settings")
            gs["litellm_key_header_name"] = litellm_key_header_name
            setattr(litellm.proxy.proxy_server, "general_settings", gs)
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        headers_to_forward = {
            "X-Custom-Header": "Custom-Value",
            "X-Another-Header": "Another-Value",
        }

        if litellm_key_header_name is not None:
            headers_to_not_forward = {litellm_key_header_name: "Bearer 1234"}
        else:
            headers_to_not_forward = {"Authorization": "Bearer 1234"}

        received_headers = {**headers_to_forward, **headers_to_not_forward}

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=received_headers
        )
        if not forward_headers:
            assert "headers" not in mock_acompletion.call_args.kwargs
        else:
            assert mock_acompletion.call_args.kwargs["headers"] == {
                "x-custom-header": "Custom-Value",
                "x-another-header": "Another-Value",
            }

        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.parametrize("forward_llm_auth_headers", [True, False])
@mock_patch_acompletion()
def test_chat_completion_forward_llm_provider_auth_headers(
    mock_acompletion, client_no_auth, forward_llm_auth_headers
):
    """
    Test that LLM provider auth headers (x-api-key, x-goog-api-key) are forwarded
    when forward_llm_provider_auth_headers=True.
    
    This allows clients to send their own LLM provider API keys through the proxy.
    """
    try:
        # Configure general settings
        gs = getattr(litellm.proxy.proxy_server, "general_settings")
        gs["forward_client_headers_to_llm_api"] = True
        gs["forward_llm_provider_auth_headers"] = forward_llm_auth_headers
        setattr(litellm.proxy.proxy_server, "general_settings", gs)
        
        # Test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hello"},
            ],
            "max_tokens": 10,
        }
        
        # Headers including LLM provider auth
        request_headers = {
            "Authorization": "Bearer sk-proxy-auth-123",  # Proxy auth (should be stripped)
            "x-api-key": "sk-ant-api03-test-anthropic-key",  # Anthropic API key
            "x-goog-api-key": "google-api-key-123",  # Google API key
            "X-Custom-Header": "custom-value",  # Custom header (should be forwarded)
        }
        
        # Make request
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=request_headers
        )
        
        assert response.status_code == 200
        
        # Check forwarded headers
        forwarded_headers = mock_acompletion.call_args.kwargs.get("headers", {})
        
        if forward_llm_auth_headers:
            # LLM provider auth headers should be forwarded
            assert "x-api-key" in forwarded_headers
            assert forwarded_headers["x-api-key"] == "sk-ant-api03-test-anthropic-key"
            assert "x-goog-api-key" in forwarded_headers
            assert forwarded_headers["x-goog-api-key"] == "google-api-key-123"
        else:
            # LLM provider auth headers should be stripped
            assert "x-api-key" not in forwarded_headers
            assert "x-goog-api-key" not in forwarded_headers
        
        # Custom headers should always be forwarded (when forward_client_headers_to_llm_api=True)
        assert "x-custom-header" in forwarded_headers
        assert forwarded_headers["x-custom-header"] == "custom-value"
        
        # Proxy Authorization should never be forwarded
        assert "authorization" not in forwarded_headers
        
        print(f"✓ Test passed with forward_llm_provider_auth_headers={forward_llm_auth_headers}")
        print(f"  Forwarded headers: {list(forwarded_headers.keys())}")
        
    except Exception as e:
        pytest.fail(f"Test failed with forward_llm_auth_headers={forward_llm_auth_headers}: {str(e)}")
    finally:
        # Clean up
        gs = getattr(litellm.proxy.proxy_server, "general_settings")
        gs.pop("forward_llm_provider_auth_headers", None)
        setattr(litellm.proxy.proxy_server, "general_settings", gs)


@mock_patch_acompletion()
@pytest.mark.asyncio
async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
    """
    If team not allowed to turn on/off guardrails

    Raise 403 forbidden error, if request is made by team on `/key/generate` or `/chat/completions`.
    """
    import asyncio
    import json
    import time

    from fastapi import HTTPException, Request
    from starlette.datastructures import URL

    from litellm.proxy._types import (
        LiteLLM_TeamTable,
        LiteLLM_TeamTableCachedObj,
        ProxyException,
        UserAPIKeyAuth,
    )
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    _team_id = "1234"
    user_key = "sk-12345678"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        team_blocked=True,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    await asyncio.sleep(1)
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    body = {"metadata": {"guardrails": {"hide_secrets": False}}}
    json_bytes = json.dumps(body).encode("utf-8")

    request._body = json_bytes

    try:
        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
        pytest.fail("Expected to raise 403 forbidden error.")
    except ProxyException as e:
        assert e.code == str(403)


from test_custom_callback_input import CompletionCustomHandler


@mock_patch_acompletion()
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    rpm_limit = 0

    mock_api_key = "sk-my-test-key"
    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)

    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)

    mock_logger = CustomLogger()
    mock_logger_unit_tests = CompletionCustomHandler()
    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    litellm.callbacks = [mock_logger, mock_logger_unit_tests]
    proxy_logging_obj._init_litellm_callbacks(llm_router=None)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

    with patch.object(
        mock_logger, "async_log_failure_event", new=AsyncMock()
    ) as mock_failed_alert:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions",
            json=test_data,
            headers={"Authorization": "Bearer {}".format(mock_api_key)},
        )
        assert response.status_code == 429

        # confirm async_log_failure_event is called
        mock_failed_alert.assert_called()

        assert len(mock_logger_unit_tests.errors) == 0


@mock_patch_acompletion()
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/engines/gpt-3.5-turbo/chat/completions", json=test_data
        )
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@mock_patch_acompletion()
def test_chat_completion_azure(mock_acompletion, client_no_auth):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with Azure Request /chat/completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)

        mock_acompletion.assert_called_once_with(
            model="azure/gpt-4.1-mini",
            messages=[
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        assert len(result["choices"][0]["message"]["content"]) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_chat_completion_azure()


@mock_patch_acompletion()
def test_openai_deployments_model_chat_completions_azure(
    mock_acompletion, client_no_auth
):
    global headers
    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
        }

        url = "/openai/deployments/azure/gpt-4.1-mini/chat/completions"
        print(f"testing proxy server with Azure Request {url}")
        response = client_no_auth.post(url, json=test_data)

        mock_acompletion.assert_called_once_with(
            model="azure/gpt-4.1-mini",
            messages=[
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            max_tokens=10,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
        assert len(result["choices"][0]["message"]["content"]) > 0
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_openai_deployments_model_chat_completions_azure()


### EMBEDDING
@mock_patch_aembedding()
def test_embedding(mock_aembedding, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "azure/text-embedding-ada-002",
            "input": ["good morning from litellm"],
        }

        async def _pre_call_hook_side_effect(**kwargs):
            data = kwargs["data"]
            metadata = {**(data.get("metadata") or {}), "source": "unit-test"}
            data["metadata"] = metadata
            proxy_request = {**(data.get("proxy_server_request") or {})}
            proxy_request["path"] = "/v1/embeddings"
            data["proxy_server_request"] = proxy_request
            return data

        async def _post_call_success_side_effect(**kwargs):
            return kwargs["response"]

        with patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "pre_call_hook",
            new=AsyncMock(side_effect=_pre_call_hook_side_effect),
        ) as mock_pre_call_hook, patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "during_call_hook",
            new=AsyncMock(return_value=None),
        ) as mock_during_hook, patch.object(
            litellm.proxy.proxy_server.proxy_logging_obj,
            "post_call_success_hook",
            new=AsyncMock(side_effect=_post_call_success_side_effect),
        ):
            response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="azure/text-embedding-ada-002",
            input=["good morning from litellm"],
            specific_deployment=True,
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536 so

        call_metadata = mock_aembedding.call_args.kwargs["metadata"]
        assert call_metadata.get("source") == "unit-test"

        pre_call_kwargs = mock_pre_call_hook.await_args_list[0].kwargs
        assert (
            pre_call_kwargs.get("call_type") == "aembedding"
        ), f"expected pre_call_hook to receive call_type='aembedding', got {pre_call_kwargs.get('call_type')}"

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@mock_patch_aembedding()
def test_bedrock_embedding(mock_aembedding, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "amazon-embeddings",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        mock_aembedding.assert_called_once_with(
            model="amazon-embeddings",
            input=["good morning from litellm"],
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        print(response.status_code, response.text)
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536 so
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.skip(reason="AWS Suspended Account")
def test_sagemaker_embedding(client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10  # this usually has len==1536 so
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# Run the test
# test_embedding()
#### IMAGE GENERATION


@mock_patch_aimage_generation()
def test_img_gen(mock_aimage_generation, client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "dall-e-3",
            "prompt": "A cute baby sea otter",
            "n": 1,
            "size": "1024x1024",
        }

        response = client_no_auth.post("/v1/images/generations", json=test_data)

        mock_aimage_generation.assert_called_once_with(
            model="dall-e-3",
            prompt="A cute baby sea otter",
            n=1,
            size="1024x1024",
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["url"]))
        assert len(result["data"][0]["url"]) > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


#### ADDITIONAL
@pytest.mark.skip(reason="test via docker tests. Requires prisma client.")
def test_add_new_model(client_no_auth):
    global headers
    try:
        test_data = {
            "model_name": "test_openai_models",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
            },
            "model_info": {"description": "this is a test openai model"},
        }
        client_no_auth.post("/model/new", json=test_data, headers=headers)
        response = client_no_auth.get("/model/info", headers=headers)
        assert response.status_code == 200
        result = response.json()
        print(f"response: {result}")
        model_info = None
        for m in result["data"]:
            if m["model_name"] == "test_openai_models":
                model_info = m["model_info"]
        assert model_info["description"] == "this is a test openai model"
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")


@pytest.mark.xdist_group("proxy_heavy")
def test_health(client_no_auth):
    global headers
    import logging
    import time

    from litellm._logging import verbose_logger, verbose_proxy_logger

    verbose_proxy_logger.setLevel(logging.DEBUG)

    try:
        response = client_no_auth.get("/health")
        assert response.status_code == 200
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


# test_add_new_model()

from litellm.integrations.custom_logger import CustomLogger


class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")
        assert kwargs["user"] == "proxy-user"
        assert kwargs["model"] == "gpt-3.5-turbo"
        assert kwargs["max_tokens"] == 10


customHandler = MyCustomHandler()


@mock_patch_acompletion()
def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
    # [PROXY: PROD TEST] - DO NOT DELETE
    # This tests if all the /chat/completion params are passed to litellm
    try:
        # Your test data
        litellm.set_verbose = True
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
            "user": "proxy-user",
        }

        litellm.callbacks = [customHandler]
        print("testing proxy server: optional params")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        mock_acompletion.assert_called_once_with(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            max_tokens=10,
            user="proxy-user",
            litellm_call_id=mock.ANY,
            litellm_logging_obj=mock.ANY,
            request_timeout=mock.ANY,
            specific_deployment=True,
            metadata=mock.ANY,
            proxy_server_request=mock.ANY,
            secret_fields=mock.ANY,
        )
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail("LiteLLM Proxy test failed. Exception", e)


# Run the test
# test_chat_completion_optional_params()


# Test Reading config.yaml file
from litellm.proxy.proxy_server import ProxyConfig


@pytest.mark.skip(reason="local variable conflicts. needs to be refactored.")
@mock.patch("litellm.proxy.proxy_server.litellm.Cache")
def test_load_router_config(mock_cache, fake_env_vars):
    mock_cache.return_value.cache.__dict__ = {"redis_client": None}
    mock_cache.return_value.supported_call_types = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
    ]

    try:
        import asyncio

        print("testing reading config")
        # this is a basic config.yaml with only a model
        filepath = os.path.dirname(os.path.abspath(__file__))
        proxy_config = ProxyConfig()
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 1

        # this is a load balancing config yaml
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 2

        # config with general settings - custom callbacks
        result = asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
            )
        )
        print(result)
        assert len(result[1]) == 2

        # tests for litellm.cache set from config
        print("testing reading proxy config for cache")
        litellm.cache = None
        asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
            )
        )
        assert litellm.cache is not None
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        assert litellm.cache.supported_call_types == [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
        ]  # init with all call types

        litellm.disable_cache()

        print("testing reading proxy config for cache with params")
        mock_cache.return_value.supported_call_types = [
            "embedding",
            "aembedding",
        ]
        asyncio.run(
            proxy_config.load_config(
                router=None,
                config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
            )
        )
        assert litellm.cache is not None
        print(litellm.cache)
        print(litellm.cache.supported_call_types)
        print(vars(litellm.cache.cache))
        assert "redis_client" in vars(
            litellm.cache.cache
        )  # it should default to redis on proxy
        assert litellm.cache.supported_call_types == [
            "embedding",
            "aembedding",
        ]  # init with all call types

    except Exception as e:
        pytest.fail(
            f"Proxy: Got exception reading config: {str(e)}\n{traceback.format_exc()}"
        )


# test_load_router_config()


@pytest.mark.asyncio
async def test_team_update_redis():
    """
    Tests if team update, updates the redis cache if set
    """
    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy._types import LiteLLM_TeamTableCachedObj
    from litellm.proxy.auth.auth_checks import _cache_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    redis_cache = RedisCache(host="localhost")

    with patch.object(
        redis_cache,
        "async_set_cache",
        new=AsyncMock(),
    ) as mock_client:
        await _cache_team_object(
            team_id="1234",
            team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
            user_api_key_cache=DualCache(redis_cache=redis_cache),
            proxy_logging_obj=proxy_logging_obj,
        )

        mock_client.assert_called()


@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
    """
    Tests if get_team_object gets value from redis cache, if set
    """
    from litellm.caching.caching import DualCache, RedisCache
    from litellm.proxy.auth.auth_checks import get_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    redis_cache = RedisCache()

    from fastapi import HTTPException

    with patch.object(
        redis_cache,
        "async_get_cache",
        new=AsyncMock(),
    ) as mock_client:
        try:
            await get_team_object(
                team_id="1234",
                user_api_key_cache=DualCache(redis_cache=redis_cache),
                parent_otel_span=None,
                proxy_logging_obj=proxy_logging_obj,
                prisma_client=AsyncMock(),
            )
        except HTTPException:
            pass

        mock_client.assert_called_once()


import random
from litellm._uuid import uuid
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

from litellm.proxy._types import (
    LitellmUserRoles,
    NewUserRequest,
    TeamMemberAddRequest,
    UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from test_key_generate_prisma import prisma_client


@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
@pytest.mark.parametrize(
    "user_role",
    [LitellmUserRoles.INTERNAL_USER.value, LitellmUserRoles.PROXY_ADMIN.value],
)
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_create_user_default_budget(prisma_client, user_role):

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    request = NewUserRequest(
        user_id=user, user_role=user_role
    )  # create a key with no budget
    with patch.object(
        litellm.proxy.proxy_server.prisma_client, "insert_data", new=AsyncMock()
    ) as mock_client:
        await new_user(
            request,
        )

        mock_client.assert_called()

        print(f"mock_client.call_args: {mock_client.call_args}")
        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

        if user_role == LitellmUserRoles.INTERNAL_USER.value:
            assert (
                mock_client.call_args.kwargs["data"]["max_budget"]
                == litellm.max_internal_user_budget
            )
            assert (
                mock_client.call_args.kwargs["data"]["budget_duration"]
                == litellm.internal_user_budget_duration
            )

        else:
            assert mock_client.call_args.kwargs["data"]["max_budget"] is None
            assert mock_client.call_args.kwargs["data"]["budget_duration"] is None


@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_create_team_member_add(prisma_client, new_member_method):
    import time

    from fastapi import Request

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, LiteLLM_UserTable
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )
    # user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    if new_member_method == "user_id":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_id": user}],
        }
    elif new_member_method == "user_email":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_email": user}],
        }
    team_member_add_request = TeamMemberAddRequest(**data)

    with patch(
        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
        new_callable=AsyncMock,
    ) as mock_litellm_usertable, patch(
        "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
        new=AsyncMock(return_value=team_obj),
    ) as mock_team_obj, patch(
        "litellm.proxy.proxy_server.prisma_client.get_data",
        new=AsyncMock(return_value=[]),
    ) as mock_get_data:

        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)
        team_mock_client = AsyncMock()
        original_val = getattr(
            litellm.proxy.proxy_server.prisma_client.db, "litellm_teamtable"
        )
        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        print(f"team_member_add_request={team_member_add_request}")
        await team_member_add(
            data=team_member_add_request,
            user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
        )

        mock_client.assert_called()

        print(f"mock_client.call_args: {mock_client.call_args}")
        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

        assert (
            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
            == litellm.max_internal_user_budget
        )
        assert (
            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
            == litellm.internal_user_budget_duration
        )

        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = original_val


@pytest.mark.parametrize("team_member_role", ["admin", "user"])
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin_user_api_key_auth(
    prisma_client, team_member_role, team_route
):
    import time

    from fastapi import Request

    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
    from litellm.proxy.proxy_server import (
        ProxyException,
        hash_token,
        user_api_key_auth,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        token=hash_token(user_key),
        team_member=Member(role=team_member_role, user_id=user),
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        metadata={"guardrails": {"modify_guardrails": False}},
    )

    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)

    ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
    import json

    from starlette.datastructures import URL

    request = Request(scope={"type": "http"})
    request._url = URL(url=team_route)

    body = {}
    json_bytes = json.dumps(body).encode("utf-8")

    request._body = json_bytes

    ## ALLOWED BY USER_API_KEY_AUTH
    await user_api_key_auth(request=request, api_key="Bearer " + user_key)


@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.parametrize("user_role", ["admin", "user"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin(
    prisma_client, new_member_method, user_role
):
    """
    Relevant issue - https://github.com/BerriAI/litellm/issues/5300

    Allow team admins to:
        - Add and remove team members
        - raise error if team member not an existing 'internal_user'
    """
    import time

    from fastapi import Request

    from litellm.proxy._types import (
        LiteLLM_TeamTableCachedObj,
        LiteLLM_UserTable,
        Member,
    )
    from litellm.proxy.proxy_server import (
        HTTPException,
        ProxyException,
        hash_token,
        user_api_key_auth,
        user_api_key_cache,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_internal_user_budget", 10)
    setattr(litellm, "internal_user_budget_duration", "5m")
    await litellm.proxy.proxy_server.prisma_client.connect()
    user = f"ishaan {uuid.uuid4().hex}"
    _team_id = "litellm-test-client-id-new"
    user_key = "sk-12345678"
    team_admin = f"krrish {uuid.uuid4().hex}"

    valid_token = UserAPIKeyAuth(
        team_id=_team_id,
        user_id=team_admin,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    team_obj = LiteLLM_TeamTableCachedObj(
        team_id=_team_id,
        blocked=False,
        last_refreshed_at=time.time(),
        members_with_roles=[Member(role=user_role, user_id=team_admin)],
        metadata={"guardrails": {"modify_guardrails": False}},
    )

    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    if new_member_method == "user_id":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_id": user}],
        }
    elif new_member_method == "user_email":
        data = {
            "team_id": _team_id,
            "member": [{"role": "user", "user_email": user}],
        }
    team_member_add_request = TeamMemberAddRequest(**data)

    with patch(
        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
        new_callable=AsyncMock,
    ) as mock_litellm_usertable, patch(
        "litellm.proxy.auth.auth_checks._get_team_object_from_user_api_key_cache",
        new=AsyncMock(return_value=team_obj),
    ) as mock_team_obj, patch(
        "litellm.proxy.proxy_server.prisma_client.get_data",
        new=AsyncMock(return_value=[]),
    ) as mock_get_data:
        mock_client = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_id="1234", max_budget=100, user_email="1234"
            )
        )
        mock_litellm_usertable.upsert = mock_client
        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
        # Mock find_first for user_email validation (returns None for new users)
        mock_litellm_usertable.find_first = AsyncMock(return_value=None)
        # Mock find_unique for user_id validation (returns None for new users)
        mock_litellm_usertable.find_unique = AsyncMock(return_value=None)

        team_mock_client = AsyncMock()
        original_val = getattr(
            litellm.proxy.proxy_server.prisma_client.db, "litellm_teamtable"
        )
        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

        team_mock_client.update = AsyncMock(
            return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
        )

        try:
            await team_member_add(
                data=team_member_add_request,
                user_api_key_dict=valid_token,
            )
        except HTTPException as e:
            if user_role == "user":
                assert e.status_code == 403
                return
            else:
                raise e

        mock_client.assert_called()

        print(f"mock_client.call_args: {mock_client.call_args}")
        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))

        assert (
            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
            == litellm.max_internal_user_budget
        )
        assert (
            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
            == litellm.internal_user_budget_duration
        )

        litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = original_val


@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_user_info_team_list(prisma_client):
    """Assert user_info for admin calls team_list function"""
    from litellm.proxy._types import LiteLLM_UserTable

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    from litellm.proxy.management_endpoints.internal_user_endpoints import user_info

    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.list_team",
        new_callable=AsyncMock,
    ) as mock_client:

        prisma_client.get_data = AsyncMock(
            return_value=LiteLLM_UserTable(
                user_role="proxy_admin",
                user_id="default_user_id",
                max_budget=None,
                user_email="",
            )
        )

        try:
            await user_info(
                request=MagicMock(),
                user_id=None,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-1234", user_id="default_user_id"
                ),
            )
        except Exception:
            pass

        mock_client.assert_called()


@pytest.mark.skip(reason="Local test")
@pytest.mark.asyncio
async def test_add_callback_via_key(prisma_client):
    """
    Test if callback specified in key, is used.
    """
    global headers
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import chat_completion

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    litellm.set_verbose = True

    try:
        # Your test data
        test_data = {
            "model": "azure/gpt-4.1-mini",
            "messages": [
                {"role": "user", "content": "write 1 sentence poem"},
            ],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        }

        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
        request._url = URL(url="/chat/completions")

        json_bytes = json.dumps(test_data).encode("utf-8")

        request._body = json_bytes

        with patch.object(
            litellm.litellm_core_utils.litellm_logging,
            "LangFuseLogger",
            new=MagicMock(),
        ) as mock_client:
            resp = await chat_completion(
                request=request,
                fastapi_response=Response(),
                user_api_key_dict=UserAPIKeyAuth(
                    metadata={
                        "logging": [
                            {
                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
                                "callback_type": "success",  # set, if required by integration - future improvement, have logging tools work for success + failure by default
                                "callback_vars": {
                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                                    "langfuse_host": "https://us.cloud.langfuse.com",
                                },
                            }
                        ]
                    }
                ),
            )
            print(resp)
            mock_client.assert_called()
            mock_client.return_value.log_event.assert_called()
            args, kwargs = mock_client.return_value.log_event.call_args
            kwargs = kwargs["kwargs"]
            assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
            assert (
                "logging"
                in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
            )
            checked_keys = False
            for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
                "logging"
            ]:
                for k, v in item["callback_vars"].items():
                    print("k={}, v={}".format(k, v))
                    if "key" in k:
                        assert "os.environ" in v
                        checked_keys = True

            assert checked_keys
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langfuse"], []),
        ("failure", [], ["langfuse"]),
        ("success_and_failure", ["langfuse"], ["langfuse"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "langfuse",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "langfuse_public_key": "my-mock-public-key",
                            "langfuse_secret_key": "my-mock-secret-key",
                            "langfuse_host": "https://us.cloud.langfuse.com",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "langfuse_public_key" in new_data
    assert new_data["langfuse_public_key"] == "my-mock-public-key"
    assert "langfuse_secret_key" in new_data
    assert new_data["langfuse_secret_key"] == "my-mock-secret-key"

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "disable_fallbacks_set",
    [
        True,
        False,
    ],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    key_metadata = {"disable_fallbacks": disable_fallbacks_set}
    existing_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [{"role": "user", "content": "write 1 sentence poem"}],
    }
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=existing_data,
        _metadata_variable_name="metadata",
    )

    assert data["disable_fallbacks"] == disable_fallbacks_set


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["gcs_bucket"], []),
        ("failure", [], ["gcs_bucket"]),
        ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "gcs_bucket",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "gcs_bucket_name": "key-logging-project1",
                            "gcs_path_service_account": "pathrise-convert-1606954137718-a956eef1a2a8.json",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "gcs_bucket_name" in new_data
    assert new_data["gcs_bucket_name"] == "key-logging-project1"
    assert "gcs_path_service_account" in new_data
    assert (
        new_data["gcs_path_service_account"]
        == "pathrise-convert-1606954137718-a956eef1a2a8.json"
    )

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",
    [
        ("success", ["langsmith"], []),
        ("failure", [], ["langsmith"]),
        ("success_and_failure", ["langsmith"], ["langsmith"]),
    ],
)
async def test_add_callback_via_key_litellm_pre_call_utils_langsmith(
    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/chat/completions")

    test_data = {
        "model": "azure/gpt-4.1-mini",
        "messages": [
            {"role": "user", "content": "write 1 sentence poem"},
        ],
        "max_tokens": 10,
        "mock_response": "Hello world",
        "api_key": "my-fake-key",
    }

    json_bytes = json.dumps(test_data).encode("utf-8")

    request._body = json_bytes

    data = {
        "data": {
            "model": "azure/gpt-4.1-mini",
            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
            "max_tokens": 10,
            "mock_response": "Hello world",
            "api_key": "my-fake-key",
        },
        "request": request,
        "user_api_key_dict": UserAPIKeyAuth(
            token=None,
            key_name=None,
            key_alias=None,
            spend=0.0,
            max_budget=None,
            expires=None,
            models=[],
            aliases={},
            config={},
            user_id=None,
            team_id=None,
            max_parallel_requests=None,
            metadata={
                "logging": [
                    {
                        "callback_name": "langsmith",
                        "callback_type": callback_type,
                        "callback_vars": {
                            "langsmith_api_key": "ls-1234",
                            "langsmith_project": "pr-brief-resemblance-72",
                            "langsmith_base_url": "https://api.smith.langchain.com",
                        },
                    }
                ]
            },
            tpm_limit=None,
            rpm_limit=None,
            budget_duration=None,
            budget_reset_at=None,
            allowed_cache_controls=[],
            permissions={},
            model_spend={},
            model_max_budget={},
            soft_budget_cooldown=False,
            litellm_budget_table=None,
            org_id=None,
            team_spend=None,
            team_alias=None,
            team_tpm_limit=None,
            team_rpm_limit=None,
            team_max_budget=None,
            team_models=[],
            team_blocked=False,
            soft_budget=None,
            team_model_aliases=None,
            team_member_spend=None,
            team_metadata=None,
            end_user_id=None,
            end_user_tpm_limit=None,
            end_user_rpm_limit=None,
            end_user_max_budget=None,
            last_refreshed_at=None,
            api_key=None,
            user_role=None,
            allowed_model_region=None,
            parent_otel_span=None,
        ),
        "proxy_config": proxy_config,
        "general_settings": {},
        "version": "0.0.0",
    }

    new_data = await add_litellm_data_to_request(**data)
    print("NEW DATA: {}".format(new_data))

    assert "langsmith_api_key" in new_data
    assert new_data["langsmith_api_key"] == "ls-1234"
    assert "langsmith_project" in new_data
    assert new_data["langsmith_project"] == "pr-brief-resemblance-72"
    assert "langsmith_base_url" in new_data
    assert new_data["langsmith_base_url"] == "https://api.smith.langchain.com"

    if expected_success_callbacks:
        assert "success_callback" in new_data
        assert new_data["success_callback"] == expected_success_callbacks

    if expected_failure_callbacks:
        assert "failure_callback" in new_data
        assert new_data["failure_callback"] == expected_failure_callbacks


@pytest.mark.skipif(
    not os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"),
    reason="Requires GEMINI_API_KEY or GOOGLE_API_KEY.",
)
@pytest.mark.asyncio
async def test_gemini_pass_through_endpoint():
    from starlette.datastructures import URL

    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        Request,
        Response,
        gemini_proxy_route,
    )

    body = b"""
        {
            "contents": [{
                "parts":[{
                "text": "The quick brown fox jumps over the lazy dog."
                }]
                }]
        }
        """

    # Construct the scope dictionary
    scope = {
        "type": "http",
        "method": "POST",
        "path": "/gemini/v1beta/models/gemini-2.5-flash:countTokens",
        "query_string": b"key=sk-1234",
        "headers": [
            (b"content-type", b"application/json"),
        ],
    }

    # Create a new Request object
    async def async_receive():
        return {"type": "http.request", "body": body, "more_body": False}

    request = Request(
        scope=scope,
        receive=async_receive,
    )

    resp = await gemini_proxy_route(
        endpoint="v1beta/models/gemini-2.5-flash:countTokens?key=sk-1234",
        request=request,
        fastapi_response=Response(),
    )

    print(resp.body)


@pytest.mark.parametrize("hidden", [True, False])
@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_alias_checks(prisma_client, hidden):
    """
    Check if model group alias is returned on

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`
    """
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    _model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
    model_alias = "gpt-4"
    router = litellm.Router(
        model_list=_model_list,
        model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
    )
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/v1/models")

    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    if hidden:
        assert len(resp["data"]) == 1
    else:
        assert len(resp["data"]) == 2
    print(resp)

    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    is_model_alias_in_list = False
    for item in models:
        if model_alias == item["model_name"]:
            is_model_alias_in_list = True

    if hidden:
        assert is_model_alias_in_list is False
    else:
        assert is_model_alias_in_list

    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    print(f"resp: {resp}")
    models = resp["data"]
    is_model_alias_in_list = False
    print(f"model_alias: {model_alias}, models: {models}")
    for item in models:
        if model_alias == item.model_group:
            is_model_alias_in_list = True

    if hidden:
        assert is_model_alias_in_list is False
    else:
        assert is_model_alias_in_list, f"models: {models}"


@pytest.mark.asyncio
@pytest.mark.skip(reason="Requires reliable external DB connection (prisma).")
async def test_proxy_model_group_info_rerank(prisma_client):
    """
    Check if rerank model is returned on the following endpoints

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`
    """
    import json

    from fastapi import HTTPException, Request, Response
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")

    _model_list = [
        {
            "model_name": "rerank-english-v3.0",
            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
            "model_info": {
                "mode": "rerank",
            },
        }
    ]
    router = litellm.Router(model_list=_model_list)
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
    request._url = URL(url="/v1/models")

    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    assert len(resp["data"]) == 1
    print(resp)

    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    assert models[0]["model_info"]["mode"] == "rerank"
    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )

    print(resp)
    models = resp["data"]
    assert models[0].mode == "rerank"


# @pytest.mark.asyncio
# async def test_proxy_team_member_add(prisma_client):
#     """
#     Add 10 people to a team. Confirm all 10 are added.
#     """
#     from litellm.proxy.management_endpoints.team_endpoints import (
#         team_member_add,
#         new_team,
#     )
#     from litellm.proxy._types import TeamMemberAddRequest, Member, NewTeamRequest

#     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
#     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
#     try:

#         async def test():
#             await litellm.proxy.proxy_server.prisma_client.connect()
#             from litellm.proxy.proxy_server import user_api_key_cache

#             user_api_key_dict = UserAPIKeyAuth(
#                 user_role=LitellmUserRoles.PROXY_ADMIN,
#                 api_key="sk-1234",
#                 user_id="1234",
#             )

#             new_team()
#             for _ in range(10):
#                 request = TeamMemberAddRequest(
#                     team_id="1234",
#                     member=Member(
#                         user_id="1234",
#                         user_role=LitellmUserRoles.INTERNAL_USER,
#                     ),
#                 )
#                 key = await team_member_add(
#                     request, user_api_key_dict=user_api_key_dict
#                 )

#             print(key)
#             user_id = key.user_id

#             # check /user/info to verify user_role was set correctly
#             new_user_info = await user_info(
#                 user_id=user_id, user_api_key_dict=user_api_key_dict
#             )
#             new_user_info = new_user_info.user_info
#             print("new_user_info=", new_user_info)
#             assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
#             assert new_user_info["user_id"] == user_id

#             generated_key = key.key
#             bearer_token = "Bearer " + generated_key

#             assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict

#             value_from_prisma = await prisma_client.get_data(
#                 token=generated_key,
#             )
#             print("token from prisma", value_from_prisma)

#             request = Request(
#                 {
#                     "type": "http",
#                     "route": api_route,
#                     "path": api_route.path,
#                     "headers": [("Authorization", bearer_token)],
#                 }
#             )

#             # use generated key to auth in
#             result = await user_api_key_auth(request=request, api_key=bearer_token)
#             print("result from user auth with new key", result)

#         asyncio.run(test())
#     except Exception as e:
#         pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup():
    from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()

    with patch.object(
        litellm.proxy.proxy_server, "PrismaClient", new=MagicMock()
    ) as mock_prisma_client:
        mock_client = mock_prisma_client.return_value  # This is the mocked instance
        mock_client.connect = AsyncMock()  # Mock the connect method
        mock_client.check_view_exists = AsyncMock()  # Mock the check_view_exists method
        mock_client.health_check = AsyncMock()  # Mock the health_check method
        mock_client._set_spend_logs_row_count_in_proxy_state = (
            AsyncMock()
        )  # Mock the _set_spend_logs_row_count_in_proxy_state method
        mock_client.start_db_health_watchdog_task = AsyncMock()
        # Mock the db attribute with start_token_refresh_task for RDS IAM token refresh
        mock_db = MagicMock()
        mock_db.start_token_refresh_task = AsyncMock()
        mock_client.db = mock_db

        await ProxyStartupEvent._setup_prisma_client(
            database_url=os.getenv("DATABASE_URL"),
            proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
            user_api_key_cache=user_api_key_cache,
        )

        # Verify our mocked methods were called
        mock_client.connect.assert_called_once()
        mock_client.check_view_exists.assert_called_once()

        # Note: This is REALLY IMPORTANT to check that the health check is called
        # This is how we ensure the DB is ready before proceeding
        mock_client.health_check.assert_called_once()

        # check that the spend logs row count is set in proxy state
        mock_client._set_spend_logs_row_count_in_proxy_state.assert_called_once()
        assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db():
    """
    PROD TEST: Test that proxy server startup fails when it's unable to connect to the database

    Think 2-3 times before editing / deleting this test, it's important for PROD
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()
    invalid_db_url = "postgresql://invalid:invalid@localhost:5432/nonexistent"

    _old_db_url = os.getenv("DATABASE_URL")
    os.environ["DATABASE_URL"] = invalid_db_url

    with pytest.raises(Exception) as exc_info:
        await ProxyStartupEvent._setup_prisma_client(
            database_url=invalid_db_url,
            proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
            user_api_key_cache=user_api_key_cache,
        )
        print("GOT EXCEPTION=", exc_info)

        assert "httpx.ConnectError" in str(exc_info.value)

    # # Verify the error message indicates a database connection issue
    # assert any(x in str(exc_info.value).lower() for x in ["database", "connection", "authentication"])

    if _old_db_url:
        os.environ["DATABASE_URL"] = _old_db_url


@pytest.mark.asyncio
async def test_get_ui_settings_spend_logs_threshold():
    """
    Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold
    """
    from litellm.proxy.management_endpoints.ui_sso import get_ui_settings
    from litellm.proxy.proxy_server import proxy_state
    from fastapi import Request
    from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY

    # Create a mock request
    mock_request = Request(
        scope={
            "type": "http",
            "headers": [],
            "method": "GET",
            "scheme": "http",
            "server": ("testserver", 80),
            "path": "/sso/get/ui_settings",
            "query_string": b"",
        }
    )

    # Test case 1: When spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY + 1
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is True
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY + 1

    # Test case 2: When spend_logs_row_count < MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY - 1
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY - 1

    # Test case 3: Edge case - exactly MAX_SPENDLOG_ROWS_TO_QUERY
    proxy_state.set_proxy_state_variable(
        "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY
    )
    response = await get_ui_settings(mock_request)
    print("response from get_ui_settings", json.dumps(response, indent=4))
    assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
    assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY

    # Clean up
    proxy_state.set_proxy_state_variable("spend_logs_row_count", 0)


@pytest.mark.asyncio
async def test_run_background_health_check_reflects_llm_model_list(monkeypatch):
    """
    Test that _run_background_health_check reflects changes to llm_model_list in each health check iteration.
    """
    import litellm.proxy.proxy_server as proxy_server
    import copy

    test_model_list_1 = [{"model_name": "model-a"}]
    test_model_list_2 = [{"model_name": "model-b"}]
    called_model_lists = []

    async def fake_perform_health_check(
        model_list, details, max_concurrency=None
    ):
        called_model_lists.append(copy.deepcopy(model_list))
        return (["healthy"], ["unhealthy"])

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(
        proxy_server, "llm_model_list", copy.deepcopy(test_model_list_1)
    )
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})

    async def fake_sleep(interval):
        raise asyncio.CancelledError()

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    monkeypatch.setattr(
        proxy_server, "llm_model_list", copy.deepcopy(test_model_list_2)
    )

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    assert len(called_model_lists) >= 2
    assert called_model_lists[0] == test_model_list_1
    assert called_model_lists[1] == test_model_list_2


@pytest.mark.asyncio
async def test_background_health_check_skip_disabled_models(monkeypatch):
    """Ensure models with disable_background_health_check are skipped."""
    import litellm.proxy.proxy_server as proxy_server
    import copy

    test_model_list = [
        {"model_name": "model-a"},
        {"model_name": "model-b", "model_info": {"disable_background_health_check": True}},
    ]
    called_model_lists = []

    async def fake_perform_health_check(
        model_list, details, max_concurrency=None
    ):
        called_model_lists.append(copy.deepcopy(model_list))
        return (["healthy"], [])

    monkeypatch.setattr(proxy_server, "health_check_interval", 1)
    monkeypatch.setattr(proxy_server, "health_check_details", None)
    monkeypatch.setattr(proxy_server, "llm_model_list", copy.deepcopy(test_model_list))
    monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check)
    monkeypatch.setattr(proxy_server, "health_check_results", {})

    async def fake_sleep(interval):
        raise asyncio.CancelledError()

    monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    try:
        await proxy_server._run_background_health_check()
    except asyncio.CancelledError:
        pass

    assert called_model_lists == [[{"model_name": "model-a"}]]


def test_get_timeout_from_request():
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    headers = {
        "x-litellm-timeout": "90",
    }
    timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
    assert timeout == 90

    headers = {
        "x-litellm-timeout": "90.5",
    }
    timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
    assert timeout == 90.5


@pytest.mark.parametrize(
    "ui_exists, ui_has_content",
    [
        (True, True),   # UI path exists and has content
        (True, False),  # UI path exists but is empty
        (False, False), # UI path doesn't exist
    ],
)
def test_non_root_ui_path_logic(monkeypatch, tmp_path, ui_exists, ui_has_content):
    """
    Test the non-root Docker UI path detection logic.
    
    Tests that when LITELLM_NON_ROOT is set to "true":
    - If UI path exists and has content, it should be used
    - If UI path doesn't exist or is empty, proper error logging occurs
    """
    import tempfile
    import shutil
    from unittest.mock import MagicMock
    
    # Create a temporary directory to act as /tmp/litellm_ui
    test_ui_path = tmp_path / "litellm_ui"
    
    if ui_exists:
        test_ui_path.mkdir(parents=True, exist_ok=True)
        if ui_has_content:
            # Create some dummy files to simulate built UI
            (test_ui_path / "index.html").write_text("<html></html>")
            (test_ui_path / "app.js").write_text("console.log('test');")
    
    # Mock the environment variable and os.path operations
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    
    # Create a mock logger to capture log messages
    mock_logger = MagicMock()
    
    # We need to reimport or reload the relevant code section
    # Since this is module-level code, we'll test the logic directly
    ui_path = None
    non_root_ui_path = str(test_ui_path)
    
    # Simulate the logic from proxy_server.py lines 909-920
    if os.getenv("LITELLM_NON_ROOT", "").lower() == "true":
        if os.path.exists(non_root_ui_path) and os.listdir(non_root_ui_path):
            mock_logger.info(f"Using pre-built UI for non-root Docker: {non_root_ui_path}")
            mock_logger.info(f"UI files found: {len(os.listdir(non_root_ui_path))} items")
            ui_path = non_root_ui_path
        else:
            mock_logger.error(f"UI not found at {non_root_ui_path}. UI will not be available.")
            mock_logger.error(f"Path exists: {os.path.exists(non_root_ui_path)}, Has content: {os.path.exists(non_root_ui_path) and bool(os.listdir(non_root_ui_path))}")
    
    # Verify behavior based on test parameters
    if ui_exists and ui_has_content:
        # UI should be found and used
        assert ui_path == non_root_ui_path
        assert mock_logger.info.call_count == 2
        mock_logger.info.assert_any_call(f"Using pre-built UI for non-root Docker: {non_root_ui_path}")
        # Verify the second info call mentions the number of items
        info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
        assert any("UI files found:" in call and "items" in call for call in info_calls)
        assert mock_logger.error.call_count == 0
    else:
        # UI should not be found, error should be logged
        assert ui_path is None
        assert mock_logger.error.call_count == 2
        mock_logger.error.assert_any_call(f"UI not found at {non_root_ui_path}. UI will not be available.")
        # Verify the second error call has path existence info
        error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
        assert any("Path exists:" in call for call in error_calls)
        assert mock_logger.info.call_count == 0


@pytest.mark.asyncio
async def test_get_config_callbacks_with_all_types(client_no_auth):
    """
    Test that /get/config/callbacks returns all three callback types:
    - success_callback with type="success"
    - failure_callback with type="failure"  
    - callbacks (success_and_failure) with type="success_and_failure"
    """
    from litellm.proxy.proxy_server import ProxyConfig
    
    # Create a mock config with all three callback types
    mock_config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse", "braintrust"],
            "failure_callback": ["sentry"],
            "callbacks": ["otel", "langsmith"]
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://test.langfuse.com",
            "BRAINTRUST_API_KEY": "test-braintrust-key",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "LANGSMITH_API_KEY": "test-langsmith-key",
        },
        "general_settings": {}
    }
    
    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
    
    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")
        
        assert response.status_code == 200
        result = response.json()
        
        # Verify response structure
        assert "status" in result
        assert result["status"] == "success"
        assert "callbacks" in result
        
        callbacks = result["callbacks"]
        
        # Verify we have all 5 callbacks (2 success + 1 failure + 2 success_and_failure)
        assert len(callbacks) == 5
        
        # Group callbacks by type
        success_callbacks = [cb for cb in callbacks if cb.get("type") == "success"]
        failure_callbacks = [cb for cb in callbacks if cb.get("type") == "failure"]
        success_and_failure_callbacks = [cb for cb in callbacks if cb.get("type") == "success_and_failure"]
        
        # Verify all callbacks have required fields
        for callback in callbacks:
            assert "name" in callback
            assert "variables" in callback
            assert "type" in callback
            assert callback["type"] in ["success", "failure", "success_and_failure"]
        
        # Verify success callbacks
        assert len(success_callbacks) == 2
        success_names = [cb["name"] for cb in success_callbacks]
        assert "langfuse" in success_names
        assert "braintrust" in success_names
        
        # Verify failure callbacks
        assert len(failure_callbacks) == 1
        assert failure_callbacks[0]["name"] == "sentry"
        
        # Verify success_and_failure callbacks
        assert len(success_and_failure_callbacks) == 2
        success_and_failure_names = [cb["name"] for cb in success_and_failure_callbacks]
        assert "otel" in success_and_failure_names
        assert "langsmith" in success_and_failure_names


@pytest.mark.asyncio
async def test_get_config_callbacks_environment_variables(client_no_auth):
    """
    Test that /get/config/callbacks correctly includes environment variables
    for each callback type. Values are returned as-is from the config (no decryption).
    """
    from litellm.proxy.proxy_server import ProxyConfig
    
    # Create a mock config with callbacks and their env vars
    mock_config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse"],
            "failure_callback": [],
            "callbacks": ["otel"]
        },
        "environment_variables": {
            "LANGFUSE_PUBLIC_KEY": "test-public-key",
            "LANGFUSE_SECRET_KEY": "test-secret-key",
            "LANGFUSE_HOST": "https://cloud.langfuse.com",
            "OTEL_EXPORTER": "otlp",
            "OTEL_ENDPOINT": "http://localhost:4317",
            "OTEL_HEADERS": "key=value",
        },
        "general_settings": {}
    }
    
    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
    
    with patch.object(
        proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data)
    ):
        response = client_no_auth.get("/get/config/callbacks")
        
        assert response.status_code == 200
        result = response.json()
        
        callbacks = result["callbacks"]
        
        # Find langfuse callback (success type)
        langfuse_callback = next(
            (cb for cb in callbacks if cb["name"] == "langfuse"), None
        )
        assert langfuse_callback is not None
        assert langfuse_callback["type"] == "success"
        assert "variables" in langfuse_callback
        
        # Verify langfuse env vars are present (values returned as-is, no decryption)
        langfuse_vars = langfuse_callback["variables"]
        assert "LANGFUSE_PUBLIC_KEY" in langfuse_vars
        assert langfuse_vars["LANGFUSE_PUBLIC_KEY"] == "test-public-key"
        assert "LANGFUSE_SECRET_KEY" in langfuse_vars
        assert langfuse_vars["LANGFUSE_SECRET_KEY"] == "test-secret-key"
        assert "LANGFUSE_HOST" in langfuse_vars
        assert langfuse_vars["LANGFUSE_HOST"] == "https://cloud.langfuse.com"
        
        # Find otel callback (success_and_failure type)
        otel_callback = next(
            (cb for cb in callbacks if cb["name"] == "otel"), None
        )
        assert otel_callback is not None
        assert otel_callback["type"] == "success_and_failure"
        assert "variables" in otel_callback
        
        # Verify otel env vars are present
        otel_vars = otel_callback["variables"]
        assert "OTEL_EXPORTER" in otel_vars
        assert otel_vars["OTEL_EXPORTER"] == "otlp"
        assert "OTEL_ENDPOINT" in otel_vars
        assert otel_vars["OTEL_ENDPOINT"] == "http://localhost:4317"
        assert "OTEL_HEADERS" in otel_vars
        assert otel_vars["OTEL_HEADERS"] == "key=value"


@pytest.mark.asyncio
async def test_update_config_success_callback_normalization():
    """
    Ensure success_callback values are normalized to lowercase when updating config.
    This prevents delete_callback (which searches lowercase) from failing on mixed case inputs like 'SQS'.
    """
    import litellm.proxy.proxy_server as proxy_server
    from litellm.proxy._types import ConfigYAML

    # Ensure feature is enabled and prisma_client is set
    setattr(proxy_server, "store_model_in_db", True)
    setattr(proxy_server, "proxy_logging_obj", MagicMock())

    class MockPrisma:
        def __init__(self):
            self.db = MagicMock()
            self.db.litellm_config = MagicMock()
            self.db.litellm_config.upsert = AsyncMock()

        # proxy_server.update_config expects this to be sync returning a dict
        def jsonify_object(self, obj):
            return obj

    setattr(proxy_server, "prisma_client", MockPrisma())

    class MockProxyConfig:
        def __init__(self):
            self.saved_config = None

        async def get_config(self):
            # Existing config has one lowercase callback already
            return {"litellm_settings": {"success_callback": ["langfuse"]}}

        async def save_config(self, new_config: dict):
            self.saved_config = new_config

        async def add_deployment(self, prisma_client=None, proxy_logging_obj=None):
            return None

    mock_proxy_config = MockProxyConfig()
    setattr(proxy_server, "proxy_config", mock_proxy_config)

    # Update config with mixed-case callbacks - expect normalization to lowercase
    config_update = ConfigYAML(litellm_settings={"success_callback": ["SQS", "sQs"]})
    await proxy_server.update_config(config_update)

    saved = mock_proxy_config.saved_config
    assert saved is not None, "save_config was not called"
    callbacks = saved["litellm_settings"]["success_callback"]

    # Deduped and normalized
    assert "sqs" in callbacks
    assert "SQS" not in callbacks
    assert "sQs" not in callbacks
    # Existing callback should still be present
    assert "langfuse" in callbacks


@pytest.mark.parametrize(
    "data",
    [
        {
            "model": {
                "model_name": "azure/gpt-4.1-mini",
                "litellm_params": {"model": "azure/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
            },
            "expected": "openai/gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "openai/gpt-4.1-mini",
                "litellm_params": {"model": "openai/gpt-4.1-mini"},
                "model_info": {"base_model": "gpt-4.1-mini"},
            },
            "expected": "gpt-4.1-mini",
        },
        {
            "model": {
                "model_name": "claude-sonnet-4-5-20250929",
                "litellm_params": {"model": "anthropic/claude-sonnet-4-5@20250929"},
                "model_info": {"base_model": "anthropic/claude-sonnet-4-5-20250929"},
            },
            "expected": "anthropic/claude-sonnet-4-5-20250929",
        },
        {
            "model": {
                "model_name": "gemini-2.5-flash-001",
                "litellm_params": {"model": "gemini/gemini-2.5-flash@001"},
                "model_info": {"base_model": "gemini-2.5-flash-001"},
            },
            "expected": "gemini-2.5-flash-001",
        },
    ],
)
def test_get_litellm_model_info(data):
    from litellm.proxy.proxy_server import get_litellm_model_info

    model = data["model"]
    get_info_mock = MagicMock()

    with mock.patch(
        "litellm.get_model_info",
        new=get_info_mock,
    ):
        get_litellm_model_info(model=model)
        get_info_mock.assert_called_once_with(data["expected"])
