# stdlib imports
import json
import os
import sys
import time
from unittest.mock import patch

import pytest

# third party imports
from click.testing import CliRunner

# NOTE(review): "../../.." is resolved against the current working directory,
# not this file's location, so the inserted path depends on where pytest is
# launched from — presumably a directory three levels below the repo root so
# that `litellm` imports from the source checkout; confirm against CI config.
sys.path.insert(
    0, os.path.abspath("../../..")
)  # Prepend the directory three levels above the CWD to the import path


# local imports
from litellm.proxy.client.cli import cli
from litellm.proxy.client.cli.commands.models import (
    format_cost_per_1k_tokens,
    format_iso_datetime_str,
    format_timestamp,
)


@pytest.fixture
def mock_client():
    """Swap the ``Client`` class used by the models commands for a mock.

    Tests configure ``mock_client.return_value`` (the would-be client
    instance) and then assert on how the class itself was constructed.
    """
    target = "litellm.proxy.client.cli.commands.models.Client"
    with patch(target) as client_cls:
        yield client_cls


@pytest.fixture
def cli_runner():
    """Provide a fresh Click ``CliRunner`` for invoking the CLI in-process."""
    runner = CliRunner()
    return runner


@pytest.fixture(autouse=True)
def mock_env():
    """Point every test at a local proxy URL with a dummy API key.

    ``patch.dict`` restores ``os.environ`` to its prior contents on exit.
    """
    test_env = {
        "LITELLM_PROXY_URL": "http://localhost:4000",
        "LITELLM_PROXY_API_KEY": "sk-test",
    }
    with patch.dict(os.environ, test_env):
        yield


@pytest.fixture
def mock_models_list(mock_client):
    """Configure the mocked client so ``models.list()`` returns two fake models."""
    fake_models = [
        {
            "id": f"model-{suffix}",
            "object": "model",
            "created": created,
            "owned_by": f"organization-{suffix}",
        }
        for suffix, created in (("123", 1699848889), ("456", 1699848890))
    ]
    mock_client.return_value.models.list.return_value = fake_models

    # The fixture itself must not construct the client, so tests can count calls.
    mock_client.assert_not_called()
    return mock_client


@pytest.fixture
def mock_models_info(mock_client):
    """Configure the mocked client so ``models.info()`` returns one gpt-4 deployment.

    The deployment has created/updated timestamps and per-token costs populated.
    """
    timestamp = "2025-04-29T21:31:43.843000+00:00"
    deployment = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4", "litellm_credential_name": "openai-1"},
        "model_info": {
            "id": "model-123",
            "created_at": timestamp,
            "updated_at": timestamp,
            "input_cost_per_token": 0.00001,
            "output_cost_per_token": 0.00002,
        },
    }
    mock_client.return_value.models.info.return_value = [deployment]

    # The fixture itself must not construct the client, so tests can count calls.
    mock_client.assert_not_called()
    return mock_client


@pytest.fixture
def force_utc_tz():
    """Temporarily run with TZ=UTC so local-time formatting is deterministic.

    ``time.tzset`` only exists on Unix; elsewhere only the environment
    variable is changed. On teardown the previous TZ value (or its absence)
    is restored and re-applied.
    """
    tzset = getattr(time, "tzset", None)
    saved_tz = os.environ.get("TZ")

    os.environ["TZ"] = "UTC"
    if tzset is not None:
        tzset()

    yield

    if saved_tz is None:
        os.environ.pop("TZ", None)
    else:
        os.environ["TZ"] = saved_tz
    if tzset is not None:
        tzset()


def test_models_list_json_format(mock_models_list, cli_runner):
    """`models list --format json` prints the client's model list as JSON."""
    list_call = mock_models_list.return_value.models.list

    result = cli_runner.invoke(cli, ["models", "list", "--format", "json"])
    assert result.exit_code == 0

    # The JSON on stdout must round-trip to exactly what the mock returned.
    parsed = json.loads(result.output)
    assert parsed == list_call.return_value

    # Client built from the env-var defaults; list() requested exactly once.
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    list_call.assert_called_once()


def test_models_list_table_format(mock_models_list, cli_runner):
    """Default `models list` output is a table with headers and model rows."""
    result = cli_runner.invoke(cli, ["models", "list"])
    assert result.exit_code == 0

    # Column headers, the first model's fields and its formatted timestamp.
    expected_fragments = (
        "ID",
        "Object",
        "Created",
        "Owned By",
        "model-123",
        "organization-123",
        format_timestamp(1699848889),
    )
    for fragment in expected_fragments:
        assert fragment in result.output

    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    mock_models_list.return_value.models.list.assert_called_once()


def test_models_list_with_base_url(mock_models_list, cli_runner):
    """Test that --base-url overrides LITELLM_PROXY_URL for `models list`.

    The API key is not passed on the command line, so it must still come from
    the LITELLM_PROXY_API_KEY value set by the autouse ``mock_env`` fixture.
    """
    custom_base_url = "http://custom.server:8000"

    # The global option is passed before the `models` sub-command group.
    result = cli_runner.invoke(cli, ["--base-url", custom_base_url, "models", "list"])

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output

    # Verify the client was created with the custom base URL (overriding env var)
    mock_models_list.assert_called_once_with(
        base_url=custom_base_url,
        api_key="sk-test",  # Should still use env var for API key
    )
    # Like the other list tests, confirm the list endpoint was actually hit.
    mock_models_list.return_value.models.list.assert_called_once()


def test_models_list_with_api_key(mock_models_list, cli_runner):
    """Test that --api-key overrides LITELLM_PROXY_API_KEY for `models list`.

    The base URL is not passed on the command line, so it must still come from
    the LITELLM_PROXY_URL value set by the autouse ``mock_env`` fixture.
    """
    custom_api_key = "custom-test-key"

    # The global option is passed before the `models` sub-command group.
    result = cli_runner.invoke(cli, ["--api-key", custom_api_key, "models", "list"])

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output

    # Verify the client was created with the custom API key (overriding env var)
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000",  # Should still use env var for base URL
        api_key=custom_api_key,
    )
    # Like the other list tests, confirm the list endpoint was actually hit.
    mock_models_list.return_value.models.list.assert_called_once()


def test_models_list_error_handling(mock_client, cli_runner):
    """Test that an exception raised by the client makes `models list` fail."""
    # Make the list call raise; the mock still records the call before raising.
    mock_client.return_value.models.list.side_effect = Exception("API Error")

    result = cli_runner.invoke(cli, ["models", "list"])

    # CliRunner captures the exception and exposes it on result.exception.
    assert result.exit_code != 0
    assert "API Error" in str(result.exception)

    # Verify the client was created with env var values
    mock_client.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    # The failure must originate from the list call, not from client setup.
    mock_client.return_value.models.list.assert_called_once()


def test_models_info_json_format(mock_models_info, cli_runner):
    """`models info --format json` dumps the client's info payload as JSON."""
    info_call = mock_models_info.return_value.models.info

    result = cli_runner.invoke(cli, ["models", "info", "--format", "json"])
    assert result.exit_code == 0

    # The JSON on stdout must round-trip to exactly what the mock returned.
    parsed = json.loads(result.output)
    assert parsed == info_call.return_value

    # Client built from the env-var defaults; info() requested exactly once.
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    info_call.assert_called_once()


def test_models_info_table_format(mock_models_info, cli_runner):
    """Default `models info` output is a table with minute-precision timestamps."""
    result = cli_runner.invoke(cli, ["models", "info"])
    assert result.exit_code == 0

    output = result.output
    # Column headers, the model name and the truncated updated_at value.
    visible = (
        "Public Model",
        "Upstream Model",
        "Updated At",
        "gpt-4",
        "2025-04-29 21:31",
    )
    for fragment in visible:
        assert fragment in output

    # Seconds and microseconds from the ISO timestamp must be dropped.
    for fragment in ("21:31:43", "843000"):
        assert fragment not in output

    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    mock_models_info.return_value.models.info.assert_called_once()


def test_models_import_only_models_matching_regex(tmp_path, mock_client, cli_runner):
    """Test the --only-models-matching-regex option for models import command.

    Writes a YAML config with four models, imports it with the regex ``gpt``
    and checks that ``client.models.new`` was called exactly once for each
    GPT model and never for the others.
    """
    import yaml as pyyaml

    # Prepare a YAML file with a mix of models
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "id-1"},
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2"},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3"},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4"},
            },
        ]
    }

    yaml_file = tmp_path / "models.yaml"
    with open(yaml_file, "w") as f:
        pyyaml.safe_dump(yaml_content, f)

    # client.models.new records one call per imported model
    mock_new = mock_client.return_value.models.new

    # 'gpt' appears in both model_name and litellm_params.model for exactly the
    # two GPT entries, so the expected selection holds whichever field the
    # regex is applied to.
    result = cli_runner.invoke(
        cli, ["models", "import", str(yaml_file), "--only-models-matching-regex", "gpt"]
    )

    assert result.exit_code == 0, result.output
    # A sorted list (not a set) also catches a model being imported twice, and
    # exact equality rules out llama2/other without separate checks.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert sorted(calls) == ["gpt-3.5-turbo", "gpt-4"]
    # The import summary is expected to mention the imported GPT models.
    assert "gpt" in result.output


def test_models_import_only_access_groups_matching_regex(
    tmp_path, mock_client, cli_runner
):
    """Test the --only-access-groups-matching-regex option for models import command.

    Writes a YAML config whose models carry various ``access_groups`` (one
    has none), imports it with the regex ``beta`` and checks that
    ``client.models.new`` was called exactly once for each model in a
    ``beta-models`` group and never for the others.
    """
    import yaml as pyyaml

    # Prepare a YAML file with a mix of models
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {
                    "id": "id-1",
                    "access_groups": ["beta-models", "prod-models"],
                },
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2", "access_groups": ["alpha-models"]},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3", "access_groups": ["beta-models"]},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4", "access_groups": ["other-group"]},
            },
            {
                "model_name": "no-access-group-model",
                "litellm_params": {"model": "no-access"},
                "model_info": {"id": "id-5"},
            },
        ]
    }

    yaml_file = tmp_path / "models.yaml"
    with open(yaml_file, "w") as f:
        pyyaml.safe_dump(yaml_content, f)

    # client.models.new records one call per imported model
    mock_new = mock_client.return_value.models.new

    # Only match models with access_groups containing 'beta'
    result = cli_runner.invoke(
        cli,
        [
            "models",
            "import",
            str(yaml_file),
            "--only-access-groups-matching-regex",
            "beta",
        ],
    )

    assert result.exit_code == 0, result.output
    # A sorted list (not a set) also catches a model being imported twice, and
    # exact equality rules out gpt-3.5, other and no-access without extra checks.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert sorted(calls) == ["gpt-4", "llama2"]
    # The import summary is expected to mention the imported GPT model.
    assert "gpt" in result.output


@pytest.mark.parametrize(
    "input_str,expected",
    [
        (None, ""),
        ("", ""),
        ("2024-05-01T12:34:56Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456+00:00", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56.123456Z", "2024-05-01 12:34"),
        ("2024-05-01T12:34:56-04:00", "2024-05-01 12:34"),
        ("2024-05-01", "2024-05-01 00:00"),
        ("not-a-date", "not-a-date"),
    ],
)
def test_format_iso_datetime_str(input_str, expected):
    """format_iso_datetime_str trims ISO-8601 strings to 'YYYY-MM-DD HH:MM'.

    The wall-clock time is kept as written: the -04:00 case shows offsets are
    not converted. Empty/None input yields "", and unparseable text passes
    through unchanged.
    """
    formatted = format_iso_datetime_str(input_str)
    assert formatted == expected


@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (1699848889, "2023-11-13 04:14"),
        (1699848889.0, "2023-11-13 04:14"),
        ("not-a-timestamp", "not-a-timestamp"),
        ([1, 2, 3], "[1, 2, 3]"),
    ],
)
def test_format_timestamp(input_val, expected, force_utc_tz):
    """format_timestamp renders epoch seconds as 'YYYY-MM-DD HH:MM'.

    ``force_utc_tz`` pins the process timezone to UTC so the expected strings
    are stable. None yields "", and non-numeric values fall back to their
    string form.
    """
    actual = format_timestamp(input_val)
    # Put the diagnostic in the assertion message instead of a stray print().
    assert actual == expected, (
        f"input: {input_val!r}, expected: {expected!r}, actual: {actual!r}"
    )


@pytest.mark.parametrize(
    "input_val,expected",
    [
        (None, ""),
        (0, "$0.0000"),
        (0.0, "$0.0000"),
        (0.00001, "$0.0100"),
        (0.00002, "$0.0200"),
        (1, "$1000.0000"),
        (1.5, "$1500.0000"),
        ("0.00001", "$0.0100"),
        ("1.5", "$1500.0000"),
        ("not-a-number", "not-a-number"),
        (1e-10, "$0.0000"),
    ],
)
def test_format_cost_per_1k_tokens(input_val, expected):
    """format_cost_per_1k_tokens scales a per-token cost by 1000 as '$X.XXXX'.

    Numeric strings are accepted, tiny values round to $0.0000, None yields "",
    and non-numeric text passes through unchanged.
    """
    actual = format_cost_per_1k_tokens(input_val)
    # Put the diagnostic in the assertion message instead of a stray print().
    assert actual == expected, (
        f"input: {input_val!r}, expected: {expected!r}, actual: {actual!r}"
    )
