"""
Tests for DeepInfra rerank functionality following repository patterns.
"""
import asyncio
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Add litellm to path
sys.path.insert(0, os.path.abspath("../../../.."))
import litellm


def assert_response_shape(response, custom_llm_provider):
    """Helper function to validate response structure."""
    assert hasattr(response, "id")
    assert hasattr(response, "results")
    assert hasattr(response, "meta")
    assert isinstance(response.results, list)

    for result in response.results:
        assert "index" in result
        assert "relevance_score" in result
        assert isinstance(result["index"], int)
        assert isinstance(result["relevance_score"], (int, float))

    # Check meta structure
    assert "tokens" in response.meta
    assert "billed_units" in response.meta
    assert "input_tokens" in response.meta["tokens"]
    assert "total_tokens" in response.meta["billed_units"]


@pytest.mark.parametrize("sync_mode", [True, False])
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_basic_rerank_deepinfra(mock_sync_post, mock_async_post, sync_mode):
    """Test basic DeepInfra rerank functionality."""
    # Mock response data that matches DeepInfra API format
    mock_response_data = {
        "scores": [0.9, 0.1],
        "input_tokens": 25,
        "request_id": "deepinfra-request-123",
        "inference_status": {
            "status": "success",
            "runtime_ms": 150,
            "cost": 0.0001,
            "tokens_generated": 0,
            "tokens_input": 25,
        },
    }

    def return_val():
        return mock_response_data

    api_key = "test_deepinfra_api_key"
    api_base = "https://api.deepinfra.com"

    if sync_mode:
        # Create mock response object for sync
        mock_response = MagicMock()
        mock_response.json = return_val
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/json"}
        mock_response.text = json.dumps(mock_response_data)
        mock_sync_post.return_value = mock_response

        response = litellm.rerank(
            model="deepinfra/Qwen/Qwen3-Reranker-0.6B",
            query="hello",
            documents=["hello", "world"],
            top_n=2,
            custom_llm_provider="deepinfra",
            api_key=api_key,
            api_base=api_base,
        )
        mock_sync_post.assert_called_once()
    else:
        # Create mock response object for async
        mock_response = AsyncMock()

        def return_val():
            return mock_response_data

        mock_response.json = return_val
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/json"}
        mock_response.text = json.dumps(mock_response_data)
        mock_async_post.return_value = mock_response

        response = asyncio.run(
            litellm.arerank(
                model="deepinfra/Qwen/Qwen3-Reranker-0.6B",
                query="hello",
                documents=["hello", "world"],
                top_n=2,
                custom_llm_provider="deepinfra",
                api_key=api_key,
                api_base=api_base,
            )
        )
        mock_async_post.assert_called_once()

    # Verify response structure
    assert response.id == "deepinfra-request-123"
    assert response.results is not None
    assert len(response.results) == 2
    assert response.results[0]["index"] == 0
    assert response.results[0]["relevance_score"] == 0.9
    assert response.results[1]["index"] == 1
    assert response.results[1]["relevance_score"] == 0.1

    # Verify metadata
    assert response.meta["tokens"]["input_tokens"] == 25
    assert response.meta["billed_units"]["total_tokens"] == 25

    # Verify hidden params specific to DeepInfra
    assert response._hidden_params["status"] == "success"
    assert response._hidden_params["runtime_ms"] == 150
    assert response._hidden_params["cost"] == 0.0001
    # Note: The model name is processed and the 'deepinfra/' prefix is removed
    assert response._hidden_params["model"] == "Qwen/Qwen3-Reranker-0.6B"

    assert_response_shape(response, custom_llm_provider="deepinfra")


@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_deepinfra_rerank_with_queries_param(mock_post):
    """Test DeepInfra rerank with multiple queries parameter."""
    mock_response_data = {
        "scores": [0.8, 0.6, 0.2],
        "input_tokens": 35,
        "request_id": "deepinfra-multi-query-123",
        "inference_status": {"status": "success", "runtime_ms": 200},
    }

    def return_val():
        return mock_response_data

    mock_response = MagicMock()
    mock_response.json = return_val
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_response.text = json.dumps(mock_response_data)
    mock_post.return_value = mock_response

    response = litellm.rerank(
        model="deepinfra/Qwen/Qwen3-Reranker-4B",
        query="hello",
        documents=["hello", "world", "test"],
        queries=["hello", "hi there"],  # DeepInfra specific param
        custom_llm_provider="deepinfra",
        api_key="test_key",
        api_base="https://api.deepinfra.com",
    )

    mock_post.assert_called_once()
    # Verify that queries parameter was passed in request
    call_data = json.loads(mock_post.call_args.kwargs["data"])
    assert "queries" in call_data
    assert call_data["queries"] == ["hello", "hi there"]

    assert response.results is not None
    assert len(response.results) == 3


@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_deepinfra_rerank_with_service_tier(mock_post):
    """Test DeepInfra rerank with service_tier parameter."""
    mock_response_data = {
        "scores": [0.95, 0.75],
        "input_tokens": 30,
        "request_id": "deepinfra-premium-123",
    }

    def return_val():
        return mock_response_data

    mock_response = MagicMock()
    mock_response.json = return_val
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_response.text = json.dumps(mock_response_data)
    mock_post.return_value = mock_response

    response = litellm.rerank(
        model="deepinfra/Qwen/Qwen3-Reranker-8B",
        query="premium search",
        documents=["doc1", "doc2"],
        service_tier="premium",  # DeepInfra specific param
        custom_llm_provider="deepinfra",
        api_key="test_key",
        api_base="https://api.deepinfra.com",
    )

    mock_post.assert_called_once()

    # Verify URL
    call_url = mock_post.call_args.kwargs["url"]
    assert "api.deepinfra.com/inference/Qwen/Qwen3-Reranker-8B" in call_url

    # Verify request contains service_tier
    call_data = json.loads(mock_post.call_args.kwargs["data"])
    assert call_data["service_tier"] == "premium"

    assert response.results is not None


@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_deepinfra_rerank_request_format(mock_post):
    """Test that the request is properly formatted for DeepInfra API."""
    mock_response_data = {"scores": [0.9, 0.1], "input_tokens": 20}

    def return_val():
        return mock_response_data

    mock_response = MagicMock()
    mock_response.json = return_val
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_response.text = json.dumps(mock_response_data)
    mock_post.return_value = mock_response

    response = litellm.rerank(
        model="deepinfra/Qwen/Qwen3-Reranker-0.6B",
        query="test query",
        documents=["doc1", "doc2"],
        custom_llm_provider="deepinfra",
        api_key="test_key",
        api_base="https://api.deepinfra.com",
        instruction="custom instruction",
        webhook="https://webhook.example.com",
    )

    mock_post.assert_called_once()

    # Verify URL format
    call_url = mock_post.call_args.kwargs["url"]
    assert call_url == "https://api.deepinfra.com/inference/Qwen/Qwen3-Reranker-0.6B"

    # Verify headers
    headers = mock_post.call_args.kwargs["headers"]
    assert headers["Authorization"] == "Bearer test_key"
    assert headers["accept"] == "application/json"
    assert headers["content-type"] == "application/json"

    # Verify request body format
    request_data = json.loads(mock_post.call_args.kwargs["data"])
    assert request_data["queries"] == [
        "test query",
        "test query",
    ]  # DeepInfra requires queries to match documents length
    assert request_data["documents"] == ["doc1", "doc2"]
    assert request_data["instruction"] == "custom instruction"
    assert request_data["webhook"] == "https://webhook.example.com"

    assert response.results is not None


@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_deepinfra_rerank_error_handling(mock_post):
    """Test DeepInfra rerank error handling."""
    error_response = {"detail": {"error": "Invalid API key"}}

    def return_val():
        return error_response

    mock_response = MagicMock()
    mock_response.status_code = 401
    mock_response.json = return_val
    mock_response.text = json.dumps(error_response)
    mock_response.headers = {"content-type": "application/json"}
    mock_post.return_value = mock_response

    # The current implementation handles errors gracefully, so we expect a successful response
    # with the error information in the hidden params
    response = litellm.rerank(
        model="deepinfra/Qwen/Qwen3-Reranker-0.6B",
        query="hello",
        documents=["hello", "world"],
        custom_llm_provider="deepinfra",
        api_key="invalid_key",
        api_base="https://api.deepinfra.com",
    )

    # Verify that the response contains error information
    assert (
        response._hidden_params["status"] == "unknown"
    )  # Default status when error occurs


def test_deepinfra_rerank_models():
    """Test that DeepInfra Qwen rerank models are recognized."""
    # These should not raise errors during model validation
    models = [
        "deepinfra/Qwen/Qwen3-Reranker-0.6B",
        "deepinfra/Qwen/Qwen3-Reranker-4B",
        "deepinfra/Qwen/Qwen3-Reranker-8B",
    ]

    for model in models:
        # This should not raise any validation errors
        try:
            litellm.get_llm_provider(model=model)
        except Exception as e:
            # We expect this to potentially fail due to missing api_base/key
            # but the model format should be recognized
            assert "api_base" in str(e) or "API key" in str(
                e
            ), f"Unexpected error for model {model}: {e}"


@patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post")
def test_deepinfra_rerank_minimal_response(mock_post):
    """Test handling of minimal DeepInfra response."""
    # Minimal response with just scores
    mock_response_data = {"scores": [0.7, 0.3]}

    def return_val():
        return mock_response_data

    mock_response = MagicMock()
    mock_response.json = return_val
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_response.text = json.dumps(mock_response_data)
    mock_post.return_value = mock_response

    response = litellm.rerank(
        model="deepinfra/Qwen/Qwen3-Reranker-0.6B",
        query="hello",
        documents=["hello", "world"],
        custom_llm_provider="deepinfra",
        api_key="test_key",
        api_base="https://api.deepinfra.com",
    )

    # Should handle minimal response gracefully
    assert response.results is not None
    assert len(response.results) == 2
    assert response.results[0]["relevance_score"] == 0.7
    assert response.results[1]["relevance_score"] == 0.3

    # Should have default values for missing fields
    assert response.meta["tokens"]["input_tokens"] == 0  # Default when missing
    assert response._hidden_params["status"] == "unknown"  # Default when missing
