### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import asyncio
import inspect
import os
import sys
import time
import traceback
from litellm._uuid import uuid
from datetime import datetime

import pytest
from pydantic import BaseModel

sys.path.insert(0, os.path.abspath("../.."))
from typing import List, Literal, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch

import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings

# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5. Caching

# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock

# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router +  proxy tests


class CompletionCustomHandler(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    The set of expected inputs to a custom handler for a
    """

    # Class variables or attributes
    def __init__(self):
        self.errors = []
        self.states: List[
            Literal[
                "sync_pre_api_call",
                "async_pre_api_call",
                "post_api_call",
                "sync_stream",
                "async_stream",
                "sync_success",
                "async_success",
                "sync_failure",
                "async_failure",
            ]
        ] = []

    def log_pre_api_call(self, model, messages, kwargs):
        try:
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            ### METADATA
            metadata_value = kwargs["litellm_params"].get("metadata")
            assert metadata_value is None or isinstance(metadata_value, dict)
            if metadata_value is not None:
                if litellm.turn_off_message_logging is True:
                    assert (
                        metadata_value["raw_request"]
                        is LiteLLMCommonStrings.redacted_by_litellm.value
                    )
                else:
                    assert "raw_request" not in metadata_value or isinstance(
                        metadata_value["raw_request"], str
                    )
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time == None
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"],
                    (str, litellm.CustomStreamWrapper, BaseModel),
                )
                or inspect.iscoroutine(kwargs["original_response"])
                or inspect.isasyncgen(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponseStream)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"\n\nkwargs={kwargs}\n\n")
            print(
                json.dumps(kwargs, default=str)
            )  # this is a test to confirm no circular references are in the logging object

            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.ImageResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and (
                    isinstance(kwargs["input"][0], dict)
                    or isinstance(kwargs["input"][0], str)
                )
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert isinstance(
                kwargs["original_response"],
                (str, litellm.CustomStreamWrapper, BaseModel),
            ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
                kwargs["original_response"]
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"kwargs: {kwargs}")
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )

            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["metadata"], Optional[dict])
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or kwargs["original_response"] == None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_pre_api_call(self, model, messages, kwargs):
        try:
            self.states.append("async_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(
                "in async_log_success_event", kwargs, response_obj, start_time, end_time
            )
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.TextCompletionResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["completion_start_time"], datetime)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, str, dict))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
                or kwargs["original_response"] == None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())


# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}],
        )
        ## test streaming
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
            stream=True,
        )
        for chunk in response:
            continue
        ## test failure callback
        try:
            response = litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                api_key="my-bad-key",
                stream=True,
            )
            for chunk in response:
                continue
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# test_chat_openai_stream()


## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
        )
        ## test streaming
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
            stream=True,
        )
        async for chunk in response:
            continue

        await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                api_key="my-bad-key",
                stream=True,
            )
            async for chunk in response:
                continue
            await asyncio.sleep(1)
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# asyncio.run(test_async_chat_openai_stream())


## Test Azure + sync
def test_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(
            model="azure/gpt-4.1-mini",
            messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
        )
        # test streaming
        response = litellm.completion(
            model="azure/gpt-4.1-mini",
            messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
            stream=True,
        )
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(
                model="azure/gpt-4.1-mini",
                messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
                api_key="my-bad-key",
                stream=True,
            )
            for chunk in response:
                continue
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# test_chat_azure_stream()


## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(
            model="azure/gpt-4.1-mini",
            messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
        )
        ## test streaming
        response = await litellm.acompletion(
            model="azure/gpt-4.1-mini",
            messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
            stream=True,
        )
        async for chunk in response:
            continue

        await asyncio.sleep(1)
        # test failure callback
        try:
            response = await litellm.acompletion(
                model="azure/gpt-4.1-mini",
                messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
                api_key="my-bad-key",
                stream=True,
            )
            async for chunk in response:
                continue
            await asyncio.sleep(1)
        except Exception:
            pass
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# asyncio.run(test_async_chat_azure_stream())


@pytest.mark.asyncio
async def test_async_chat_openai_stream_options():
    try:
        litellm.set_verbose = True
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        with patch.object(
            customHandler, "async_log_success_event", new=AsyncMock()
        ) as mock_client:
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hi 👋 - i'm async openai"}],
                stream=True,
                stream_options={"include_usage": True},
            )

            async for chunk in response:
                continue

            await asyncio.sleep(1)
            print("mock client args list=", mock_client.await_args_list)
            mock_client.assert_awaited_once()
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# asyncio.run(test_async_chat_bedrock_stream())


## Test Sagemaker + Async
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_async_chat_sagemaker_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(
            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
        )
        # test streaming
        response = await litellm.acompletion(
            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
            stream=True,
        )
        print(f"response: {response}")
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(
                model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
                messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
                aws_region_name="my-bad-key",
                stream=True,
            )
            async for chunk in response:
                continue
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


## Test Vertex AI + Async
import json
import tempfile


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


@pytest.mark.skip(reason="Vertex AI Hanging")
@pytest.mark.asyncio
async def test_async_chat_vertex_ai_stream():
    try:
        load_vertex_ai_credentials()
        customHandler = CompletionCustomHandler()
        litellm.set_verbose = True
        litellm.callbacks = [customHandler]
        # test streaming
        response = await litellm.acompletion(
            model="gemini-pro",
            messages=[
                {
                    "role": "user",
                    "content": f"Hi 👋 - i'm async vertex_ai {uuid.uuid4()}",
                }
            ],
            stream=True,
        )
        print(f"response: {response}")
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
        await asyncio.sleep(10)
        print(f"customHandler.states: {customHandler.states}")
        assert (
            customHandler.states.count("async_success") == 1
        )  # pre, post, success, pre, post, failure
        assert len(customHandler.states) >= 3  # pre, post, success
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# Text Completion


@pytest.mark.asyncio
@pytest.mark.skip(reason="temp-skip to see what else is failing")
async def test_async_text_completion_bedrock():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.atext_completion(
            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
            prompt=["Hi 👋 - i'm async text completion bedrock"],
        )
        # test streaming
        response = await litellm.atext_completion(
            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
            prompt=["Hi 👋 - i'm async text completion bedrock"],
            stream=True,
        )
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue

        await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.atext_completion(
                model="bedrock/",
                prompt=["Hi 👋 - i'm async text completion bedrock"],
                stream=True,
                api_key="my-bad-key",
            )
            async for chunk in response:
                continue

            await asyncio.sleep(1)
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


## Test OpenAI text completion + Async
@pytest.mark.asyncio
async def test_async_text_completion_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.atext_completion(
            model="gpt-3.5-turbo",
            prompt="Hi 👋 - i'm async text completion openai",
        )
        # test streaming
        response = await litellm.atext_completion(
            model="gpt-3.5-turbo",
            prompt="Hi 👋 - i'm async text completion openai",
            stream=True,
        )
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue

        await asyncio.sleep(1)
        ## test failure callback
        try:
            response = await litellm.atext_completion(
                model="gpt-3.5-turbo",
                prompt="Hi 👋 - i'm async text completion openai",
                stream=True,
                api_key="my-bad-key",
            )
            async for chunk in response:
                continue

            await asyncio.sleep(1)
        except Exception:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["good morning from litellm"],
        )
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(
                model="text-embedding-ada-002",
                input=["good morning from litellm"],
                api_key="my-bad-key",
            )
        except Exception:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# asyncio.run(test_async_embedding_openai())


## Test Azure + Async
def test_amazing_sync_embedding():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = litellm.embedding(
            model="azure/text-embedding-ada-002", input=["good morning from litellm"]
        )
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        time.sleep(2)
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [customHandler_failure]
        try:
            response = litellm.embedding(
                model="azure/text-embedding-ada-002",
                input=["good morning from litellm"],
                api_key="my-bad-key",
            )
        except Exception:
            pass
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        time.sleep(2)
        assert len(customHandler_failure.errors) == 1
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = await litellm.aembedding(
            model="azure/text-embedding-ada-002", input=["good morning from litellm"]
        )
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(
                model="azure/text-embedding-ada-002",
                input=["good morning from litellm"],
                api_key="my-bad-key",
            )
        except Exception:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, success
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# asyncio.run(test_async_embedding_azure())


## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="bedrock/cohere.embed-multilingual-v3",
            input=["good morning from litellm"],
            aws_region_name="us-east-1",
        )
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(
                model="bedrock/cohere.embed-multilingual-v3",
                input=["good morning from litellm"],
                aws_region_name="my-bad-region",
            )
        except Exception:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, success
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# Image Generation


## Test OpenAI + Sync
@pytest.mark.flaky(retries=3, delay=1)
def test_image_generation_openai():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]

        litellm.set_verbose = True

        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model="openai/dall-e-3",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

        print(f"response: {response}")
        assert len(response.data) > 0

        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        time.sleep(2)
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [customHandler_failure]
        try:
            response = litellm.image_generation(
                prompt="A cute baby sea otter",
                model="dall-e-2",
                api_key="my-bad-api-key",
            )
        except Exception:
            pass
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except litellm.RateLimitError as e:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # OpenAI randomly raises these errors - skip when they occur
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


# test_image_generation_openai()
## Test OpenAI + Async

## Test Azure + Sync

## Test Azure + Async

##### PII REDACTION ######


def test_turn_off_message_logging():
    """
    If 'turn_off_message_logging' is true, assert no user request information is logged.
    """
    litellm.turn_off_message_logging = True

    # sync completion
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    _ = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        mock_response="Going well!",
    )

    time.sleep(2)
    assert len(customHandler.errors) == 0


##### VALID JSON ######


@pytest.mark.parametrize(
    "model",
    [
        "ft:gpt-3.5-turbo:my-org:custom_suffix:id"
    ],  # "gpt-3.5-turbo", "azure/gpt-4.1-mini",
)
@pytest.mark.parametrize(
    "turn_off_message_logging",
    [
        True,
    ],
)  # False
def test_standard_logging_payload(model, turn_off_message_logging):
    """
    Ensure valid standard_logging_payload is passed for logging calls to s3

    Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
    """
    from litellm.types.utils import StandardLoggingPayload

    # sync completion
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    litellm.turn_off_message_logging = turn_off_message_logging

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
        _ = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            mock_response="Going well!",
        )

        time.sleep(2)
        mock_client.assert_called_once()

        print(
            f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
        )
        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            is not None
        )

        print(
            "Standard Logging Object - {}".format(
                mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            )
        )

        keys_list = list(StandardLoggingPayload.__annotations__.keys())

        for k in keys_list:
            assert (
                k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            )

        ## json serializable
        json_str_payload = json.dumps(
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
        )
        json.loads(json_str_payload)

        ## response cost
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
                "response_cost"
            ]
            > 0
        )
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
                "model_map_information"
            ]["model_map_value"]
            is not None
        )

        ## turn off message logging
        slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
            "standard_logging_object"
        ]
        if turn_off_message_logging:
            print("checks redacted-by-litellm")
            assert "redacted-by-litellm" == slobject["messages"][0]["content"]
            # response is a full ModelResponse dict (choices format) since d84e5e381acf
            response = slobject["response"]
            assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
            assert response["choices"][0]["message"].get("audio") is None


@pytest.mark.parametrize(
    "stream",
    [True, False],
)
@pytest.mark.parametrize(
    "turn_off_message_logging",
    [
        True,
    ],
)  # False
def test_standard_logging_payload_audio(turn_off_message_logging, stream):
    """
    Ensure valid standard_logging_payload is passed for logging calls to s3

    Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
    """
    from litellm.types.utils import StandardLoggingPayload

    # sync completion
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    litellm.turn_off_message_logging = turn_off_message_logging

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
        try:
            response = litellm.completion(
                model="gpt-4o-audio-preview",
                modalities=["text", "audio"],
                audio={"voice": "alloy", "format": "pcm16"},
                messages=[
                    {"role": "user", "content": "response in 1 word - yes or no"}
                ],
                stream=stream,
            )
        except Exception as e:
            if "openai-internal" in str(e):
                pytest.skip("Skipping test due to openai-internal error")

        if stream:
            for chunk in response:
                continue

        time.sleep(2)
        mock_client.assert_called()

        print(
            f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
        )
        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            is not None
        )

        print(
            "Standard Logging Object - {}".format(
                mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            )
        )

        keys_list = list(StandardLoggingPayload.__annotations__.keys())

        for k in keys_list:
            assert (
                k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            )

        ## json serializable
        json_str_payload = json.dumps(
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
        )
        json.loads(json_str_payload)

        ## response cost
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
                "response_cost"
            ]
            > 0
        )
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
                "model_map_information"
            ]["model_map_value"]
            is not None
        )

        ## turn off message logging
        slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
            "standard_logging_object"
        ]
        if turn_off_message_logging:
            print("checks redacted-by-litellm")
            assert "redacted-by-litellm" == slobject["messages"][0]["content"]
            # response is a full ModelResponse dict (choices format) since d84e5e381acf
            response = slobject["response"]
            assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
            assert response["choices"][0]["message"].get("audio") is None


@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
def test_aaastandard_logging_payload_cache_hit():
    from litellm.types.utils import StandardLoggingPayload

    # sync completion

    litellm.cache = Cache()

    _ = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        caching=True,
    )

    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    litellm.success_callback = []

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
        _ = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            caching=True,
        )

        time.sleep(2)
        mock_client.assert_called_once()

        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            is not None
        )

        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
            "kwargs"
        ]["standard_logging_object"]

        assert standard_logging_object["cache_hit"] is True
        assert standard_logging_object["response_cost"] == 0
        assert standard_logging_object["saved_cache_cost"] > 0


@pytest.mark.parametrize(
    "turn_off_message_logging",
    [False, True],
)  # False
def test_logging_async_cache_hit_sync_call(turn_off_message_logging):
    from litellm.types.utils import StandardLoggingPayload

    litellm.turn_off_message_logging = turn_off_message_logging

    litellm.cache = Cache()

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        caching=True,
        stream=True,
    )
    for chunk in response:
        print(chunk)

    time.sleep(3)
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    litellm.success_callback = []

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
        resp = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            caching=True,
            stream=True,
        )

        for chunk in resp:
            print(chunk)

        time.sleep(2)
        mock_client.assert_called_once()

        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            is not None
        )

        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
            "kwargs"
        ]["standard_logging_object"]

        assert standard_logging_object["cache_hit"] is True
        assert standard_logging_object["response_cost"] == 0
        assert standard_logging_object["saved_cache_cost"] > 0

        if turn_off_message_logging:
            print("checks redacted-by-litellm")
            assert (
                "redacted-by-litellm"
                == standard_logging_object["messages"][0]["content"]
            )
            assert {"text": "redacted-by-litellm"} == standard_logging_object[
                "response"
            ]


def test_logging_standard_payload_failure_call():
    from litellm.types.utils import StandardLoggingPayload

    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    with patch.object(
        customHandler, "log_failure_event", new=MagicMock()
    ) as mock_client:
        try:
            resp = litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                api_key="my-bad-api-key",
            )
        except litellm.AuthenticationError:
            pass

        mock_client.assert_called_once()

        assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
        assert (
            mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            is not None
        )

        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
            "kwargs"
        ]["standard_logging_object"]
        assert "additional_headers" in standard_logging_object["hidden_params"]


@pytest.mark.parametrize("stream", [False, True])
def test_logging_standard_payload_llm_headers(stream):
    from litellm.types.utils import StandardLoggingPayload

    # sync completion
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:

        resp = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            stream=stream,
        )

        if stream:
            for chunk in resp:
                continue

        time.sleep(2)
        mock_client.assert_called()

        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
            "kwargs"
        ]["standard_logging_object"]

        print(standard_logging_object["hidden_params"]["additional_headers"])


def test_logging_key_masking_gemini():
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    litellm.success_callback = []

    with patch.object(
        customHandler, "log_pre_api_call", new=MagicMock()
    ) as mock_client:
        try:
            resp = litellm.completion(
                model="gemini/gemini-1.5-pro",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                api_key="LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART",
            )
        except litellm.AuthenticationError:
            pass

        mock_client.assert_called()

        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        assert (
            "LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
            not in mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
        )
        key = mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
        trimmed_key = key.split("key=")[1]
        trimmed_key = trimmed_key.replace("*", "")
        assert "PART" == trimmed_key


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_standard_logging_payload_stream_usage(sync_mode):
    """
    Even if stream_options is not provided, correct usage should be logged
    """
    from litellm.types.utils import StandardLoggingPayload
    from litellm.main import stream_chunk_builder

    stream = True
    try:
        # sync completion
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]

        if sync_mode:
            patch_event = "log_success_event"
            return_val = MagicMock()
        else:
            patch_event = "async_log_success_event"
            return_val = AsyncMock()

        with patch.object(customHandler, patch_event, new=return_val) as mock_client:
            if sync_mode:
                resp = litellm.completion(
                    model="anthropic/claude-sonnet-4-5-20250929",
                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
                    stream=stream,
                )

                chunks = []
                for chunk in resp:
                    chunks.append(chunk)
                time.sleep(2)
            else:
                resp = await litellm.acompletion(
                    model="anthropic/claude-sonnet-4-5-20250929",
                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
                    stream=stream,
                )

                chunks = []
                async for chunk in resp:
                    chunks.append(chunk)
                await asyncio.sleep(2)

            mock_client.assert_called_once()

            standard_logging_object: StandardLoggingPayload = (
                mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
            )

            built_response = stream_chunk_builder(chunks=chunks)
            print(f"built_response: {built_response}")
            assert (
                built_response.usage.total_tokens
                == standard_logging_object["total_tokens"]
            )
            print(f"standard_logging_object usage: {built_response.usage}")
    except litellm.InternalServerError:
        pass


def test_standard_logging_retries():
    """
    know if a request was retried.
    """
    from litellm.types.utils import StandardLoggingPayload
    from litellm.router import Router

    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": "test-api-key",
                },
            }
        ]
    )

    with patch.object(
        customHandler, "log_failure_event", new=MagicMock()
    ) as mock_client:
        try:
            router.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                num_retries=1,
                mock_response="litellm.RateLimitError",
            )
        except litellm.RateLimitError:
            pass

        assert mock_client.call_count == 2
        assert (
            mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
                "trace_id"
            ]
            is not None
        )
        assert (
            mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
                "trace_id"
            ]
            == mock_client.call_args_list[1].kwargs["kwargs"][
                "standard_logging_object"
            ]["trace_id"]
        )


@pytest.mark.parametrize("disable_no_log_param", [True, False])
def test_litellm_logging_no_log_param(monkeypatch, disable_no_log_param):
    monkeypatch.setattr(litellm, "global_disable_no_log_param", disable_no_log_param)
    from litellm.litellm_core_utils.litellm_logging import Logging

    litellm.success_callback = ["langfuse"]
    litellm_call_id = "my-unique-call-id"
    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id=litellm_call_id,
        start_time=datetime.now(),
        function_id="1234",
    )

    should_run = litellm_logging_obj.should_run_callback(
        callback="langfuse",
        litellm_params={"no-log": True},
        event_hook="success_handler",
    )

    if disable_no_log_param:
        assert should_run is True
    else:
        assert should_run is False
