import io
import os
import sys


sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.proxy.utils import log_db_metrics, ServiceTypes
from datetime import datetime
import httpx
from prisma.errors import ClientNotConnectedError


# Test async function to decorate
@log_db_metrics
async def sample_db_function(*args, **kwargs):
    return "success"


@log_db_metrics
async def sample_proxy_function(*args, **kwargs):
    return "success"


@pytest.mark.asyncio
async def test_log_db_metrics_success():
    # Mock the proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        mock_proxy_logging.service_logging_obj.async_service_success_hook = AsyncMock()

        # Call the decorated function
        result = await sample_db_function(parent_otel_span="test_span")

        # Assertions
        assert result == "success"
        mock_proxy_logging.service_logging_obj.async_service_success_hook.assert_called_once()
        call_args = (
            mock_proxy_logging.service_logging_obj.async_service_success_hook.call_args[
                1
            ]
        )

        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "sample_db_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert "function_name" in call_args["event_metadata"]


@pytest.mark.asyncio
async def test_log_db_metrics_duration():
    # Mock the proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        mock_proxy_logging.service_logging_obj.async_service_success_hook = AsyncMock()

        # Add a delay to the function to test duration
        @log_db_metrics
        async def delayed_function(**kwargs):
            await asyncio.sleep(1)  # 1 second delay
            return "success"

        # Call the decorated function
        start = time.time()
        result = await delayed_function(parent_otel_span="test_span")
        end = time.time()

        # Get the actual duration
        actual_duration = end - start

        # Get the logged duration from the mock call
        call_args = (
            mock_proxy_logging.service_logging_obj.async_service_success_hook.call_args[
                1
            ]
        )
        logged_duration = call_args["duration"]

        # Assert the logged duration is approximately equal to actual duration (within 0.1 seconds)
        assert abs(logged_duration - actual_duration) < 0.1
        assert result == "success"


@pytest.mark.asyncio
async def test_log_db_metrics_failure():
    """
    should log a failure if a prisma error is raised
    """
    # Mock the proxy_logging_obj
    from prisma.errors import ClientNotConnectedError

    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()

        # Create a failing function
        @log_db_metrics
        async def failing_function(**kwargs):
            raise ClientNotConnectedError()

        # Call the decorated function and expect it to raise
        with pytest.raises(ClientNotConnectedError) as exc_info:
            await failing_function(parent_otel_span="test_span")

        # Assertions
        assert "Client is not connected to the query engine" in str(exc_info.value)
        mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
        call_args = (
            mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
                1
            ]
        )

        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "failing_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert isinstance(call_args["error"], ClientNotConnectedError)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception,should_log",
    [
        (ValueError("Generic error"), False),
        (KeyError("Missing key"), False),
        (TypeError("Type error"), False),
        (httpx.ConnectError("Failed to connect"), True),
        (httpx.TimeoutException("Request timed out"), True),
        (ClientNotConnectedError(), True),  # Prisma error
    ],
)
async def test_log_db_metrics_failure_error_types(exception, should_log):
    """
    Why Test?
    Users were seeing that non-DB errors were being logged as DB Service Failures
    Example a failure to read a value from cache was being logged as a DB Service Failure


    Parameterized test to verify:
    - DB-related errors (Prisma, httpx) are logged as service failures
    - Non-DB errors (ValueError, KeyError, etc.) are not logged
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = AsyncMock()

        @log_db_metrics
        async def failing_function(**kwargs):
            raise exception

        # Call the function and expect it to raise the exception
        with pytest.raises(type(exception)):
            await failing_function(parent_otel_span="test_span")

        if should_log:
            # Assert failure was logged for DB-related errors
            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_called_once()
            call_args = mock_proxy_logging.service_logging_obj.async_service_failure_hook.call_args[
                1
            ]
            assert call_args["service"] == ServiceTypes.DB
            assert call_args["call_type"] == "failing_function"
            assert call_args["parent_otel_span"] == "test_span"
            assert isinstance(call_args["duration"], float)
            assert isinstance(call_args["start_time"], datetime)
            assert isinstance(call_args["end_time"], datetime)
            assert isinstance(call_args["error"], type(exception))
        else:
            # Assert failure was NOT logged for non-DB errors
            mock_proxy_logging.service_logging_obj.async_service_failure_hook.assert_not_called()


@pytest.mark.asyncio
async def test_dd_log_db_spend_failure_metrics():
    from litellm._service_logger import ServiceLogging
    from litellm.integrations.datadog.datadog import DataDogLogger

    dd_logger = DataDogLogger()
    with patch.object(dd_logger, "async_service_failure_hook", new_callable=AsyncMock):
        service_logging_obj = ServiceLogging()

        litellm.service_callback = [dd_logger]

        await service_logging_obj.async_service_failure_hook(
            service=ServiceTypes.DB,
            call_type="test_call_type",
            error="test_error",
            duration=1.0,
        )
