import pytest
import sys
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

# Add project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from litellm.proxy.auth.user_api_key_auth import (
    _resolve_jwt_to_virtual_key,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy._types import (
    JWTKeyMappingResponse,
    LiteLLM_JWTAuth,
    LitellmUserRoles,
    UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
    _to_response,
    create_jwt_key_mapping,
    delete_jwt_key_mapping,
    info_jwt_key_mapping,
    update_jwt_key_mapping,
)
from litellm.caching.caching import DualCache
from fastapi import HTTPException


# ──────────────────────────────────────────────
# Tests: _resolve_jwt_to_virtual_key
# ──────────────────────────────────────────────


@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_resolution():
    """
    A JWT claim with a stored mapping resolves to the mapped virtual key, and a
    repeated lookup is served without calling ``find_first`` a second time.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(
        virtual_key_claim_field="email", virtual_key_mapping_cache_ttl=3600
    )

    # Active mapping row handed back by the mocked DB lookup.
    mapping_row = MagicMock()
    mapping_row.token = "sk-1234"
    mapping_row.is_active = True

    db_client = MagicMock()
    db_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=mapping_row)
    find_first = db_client.db.litellm_jwtkeymapping.find_first

    # Key object the patched get_key_object will hand back.
    expected_key = UserAPIKeyAuth(token="sk-1234", team_id="team1")

    # Both resolution calls share the exact same arguments (including the cache).
    resolve_kwargs = {
        "jwt_claims": {"email": "user@example.com", "sub": "123"},
        "jwt_handler": handler,
        "prisma_client": db_client,
        "user_api_key_cache": DualCache(),
        "parent_otel_span": None,
        "proxy_logging_obj": None,
    }

    # Patch get_key_object in the module where it is looked up.
    key_lookup_patch = patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
    )
    with key_lookup_patch as mock_get_key:
        mock_get_key.return_value = expected_key

        first = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert first == expected_key
        find_first.assert_called_once()

        # Cache hit: the mapping table must not be queried again.
        find_first.reset_mock()
        second = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert second == expected_key
        find_first.assert_not_called()


@pytest.mark.asyncio
async def test_jwt_to_virtual_key_mapping_no_mapping():
    """
    With no mapping row for the claim, resolution yields None, and a repeated
    lookup for the same claim does not call ``find_first`` again.
    """
    handler = JWTHandler()
    handler.litellm_jwtauth = LiteLLM_JWTAuth(virtual_key_claim_field="email")

    # The mocked DB lookup finds nothing.
    db_client = MagicMock()
    db_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
    find_first = db_client.db.litellm_jwtkeymapping.find_first

    # get_key_object is patched defensively; it is not expected to matter here.
    key_lookup_patch = patch(
        "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
    )
    with key_lookup_patch:
        resolve_kwargs = {
            "jwt_claims": {"email": "unknown@example.com"},
            "jwt_handler": handler,
            "prisma_client": db_client,
            "user_api_key_cache": DualCache(),
            "parent_otel_span": None,
            "proxy_logging_obj": None,
        }

        first = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert first is None

        # Negative cache hit: the miss is remembered, so no second DB query.
        find_first.reset_mock()
        second = await _resolve_jwt_to_virtual_key(**resolve_kwargs)
        assert second is None
        find_first.assert_not_called()


# ──────────────────────────────────────────────
# Tests: _to_response redacts hashed token
# ──────────────────────────────────────────────


def test_to_response_excludes_token():
    """_to_response should not expose the hashed token field.

    Verifies the redaction at two levels: the response model declares no
    ``token`` field, and the serialized response carries neither a ``token``
    key nor the hashed value itself.
    """
    now = datetime.now(timezone.utc)
    mock_mapping = MagicMock()
    mock_mapping.id = "mapping-1"
    mock_mapping.jwt_claim_name = "email"
    mock_mapping.jwt_claim_value = "user@example.com"
    mock_mapping.token = "hashed_secret_value"
    mock_mapping.description = "test"
    mock_mapping.is_active = True
    mock_mapping.created_at = now
    mock_mapping.updated_at = now
    mock_mapping.created_by = "admin"
    mock_mapping.updated_by = "admin"

    resp = _to_response(mock_mapping)

    assert isinstance(resp, JWTKeyMappingResponse)
    assert resp.id == "mapping-1"
    assert resp.jwt_claim_name == "email"

    # Read model_fields from the class: accessing it on an instance is
    # deprecated since Pydantic 2.11.
    assert "token" not in JWTKeyMappingResponse.model_fields

    # Schema-level absence is not enough; also confirm nothing leaks into the
    # actual serialized payload.
    serialized = resp.model_dump()
    assert "token" not in serialized
    assert "hashed_secret_value" not in serialized.values()


# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────


def _make_admin_auth() -> UserAPIKeyAuth:
    """Return a UserAPIKeyAuth carrying the PROXY_ADMIN role (token "sk-admin")."""
    admin_role = LitellmUserRoles.PROXY_ADMIN
    return UserAPIKeyAuth(token="sk-admin", user_role=admin_role)


def _make_non_admin_auth() -> UserAPIKeyAuth:
    """Return a UserAPIKeyAuth carrying the INTERNAL_USER role (token "sk-user")."""
    regular_role = LitellmUserRoles.INTERNAL_USER
    return UserAPIKeyAuth(token="sk-user", user_role=regular_role)


def _mock_prisma():
    """Build a MagicMock Prisma client for the ``litellm_jwtkeymapping`` table.

    ``create``, ``find_unique``, ``find_many``, ``update`` and ``delete`` are
    plain AsyncMocks; ``count`` is an AsyncMock that resolves to 0.
    """
    client = MagicMock()
    table = client.db.litellm_jwtkeymapping
    for method_name in ("create", "find_unique", "find_many", "update", "delete"):
        setattr(table, method_name, AsyncMock())
    table.count = AsyncMock(return_value=0)
    return client


def _mock_mapping(
    id="mapping-1",
    claim_name="email",
    claim_value="user@example.com",
):
    """Build a MagicMock standing in for a JWT key-mapping DB row.

    The row is active, has no description, carries the token "hashed_token",
    and uses one shared UTC timestamp for ``created_at`` and ``updated_at``.
    """
    timestamp = datetime.now(timezone.utc)
    attributes = {
        "id": id,
        "jwt_claim_name": claim_name,
        "jwt_claim_value": claim_value,
        "token": "hashed_token",
        "description": None,
        "is_active": True,
        "created_at": timestamp,
        "updated_at": timestamp,
        "created_by": "admin",
        "updated_by": "admin",
    }
    row = MagicMock()
    for attr_name, attr_value in attributes.items():
        setattr(row, attr_name, attr_value)
    return row


# ──────────────────────────────────────────────
# Tests: CRUD endpoint error handling
# ──────────────────────────────────────────────


@pytest.mark.asyncio
async def test_create_returns_409_on_unique_violation():
    """A duplicate mapping must surface as HTTP 409 rather than 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    # The DB create call fails with a unique-constraint style error message.
    db = _mock_prisma()
    db.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Unique constraint failed (P2002)"
    )
    cache = AsyncMock()

    request = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test-key",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", db):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache):
            with pytest.raises(HTTPException) as raised:
                await create_jwt_key_mapping(
                    data=request, user_api_key_dict=_make_admin_auth()
                )
            assert raised.value.status_code == 409
            assert "already exists" in raised.value.detail


@pytest.mark.asyncio
async def test_create_returns_400_on_foreign_key_violation():
    """A mapping that points at a non-existent key must surface as HTTP 400 rather than 500."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    # The DB create call fails with a foreign-key style error message.
    db = _mock_prisma()
    db.db.litellm_jwtkeymapping.create.side_effect = Exception(
        "Foreign key constraint failed on field: `token` (P2003)"
    )
    cache = AsyncMock()

    request = CreateJWTKeyMappingRequest(
        jwt_claim_name="sub",
        jwt_claim_value="user-999",
        key="sk-nonexistent",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", db):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache):
            with pytest.raises(HTTPException) as raised:
                await create_jwt_key_mapping(
                    data=request, user_api_key_dict=_make_admin_auth()
                )
            assert raised.value.status_code == 400
            assert "does not match" in raised.value.detail


@pytest.mark.asyncio
async def test_create_non_admin_returns_403():
    """A caller without the admin role must be rejected with HTTP 403."""
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    request = CreateJWTKeyMappingRequest(
        jwt_claim_name="email",
        jwt_claim_value="user@example.com",
        key="sk-test",
    )
    caller = _make_non_admin_auth()

    with pytest.raises(HTTPException) as raised:
        await create_jwt_key_mapping(data=request, user_api_key_dict=caller)

    assert raised.value.status_code == 403


@pytest.mark.asyncio
async def test_delete_returns_404_when_not_found():
    """Deleting a mapping whose id is not found must yield HTTP 404."""
    from litellm.proxy._types import DeleteJWTKeyMappingRequest

    # find_unique resolves to None, i.e. no such row.
    db = _mock_prisma()
    db.db.litellm_jwtkeymapping.find_unique.return_value = None
    cache = AsyncMock()

    request = DeleteJWTKeyMappingRequest(id="nonexistent-id")

    with patch("litellm.proxy.proxy_server.prisma_client", db):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache):
            with pytest.raises(HTTPException) as raised:
                await delete_jwt_key_mapping(
                    data=request, user_api_key_dict=_make_admin_auth()
                )
            assert raised.value.status_code == 404


@pytest.mark.asyncio
async def test_update_returns_404_when_not_found():
    """Updating a mapping whose id is not found must yield HTTP 404."""
    from litellm.proxy._types import UpdateJWTKeyMappingRequest

    # find_unique resolves to None, i.e. no such row.
    db = _mock_prisma()
    db.db.litellm_jwtkeymapping.find_unique.return_value = None
    cache = AsyncMock()

    request = UpdateJWTKeyMappingRequest(id="nonexistent-id", description="test")

    with patch("litellm.proxy.proxy_server.prisma_client", db):
        with patch("litellm.proxy.proxy_server.user_api_key_cache", cache):
            with pytest.raises(HTTPException) as raised:
                await update_jwt_key_mapping(
                    data=request, user_api_key_dict=_make_admin_auth()
                )
            assert raised.value.status_code == 404


@pytest.mark.asyncio
async def test_info_returns_404_when_not_found():
    """Requesting info for a mapping whose id is not found must yield HTTP 404."""
    # find_unique resolves to None, i.e. no such row.
    db = _mock_prisma()
    db.db.litellm_jwtkeymapping.find_unique.return_value = None

    with patch("litellm.proxy.proxy_server.prisma_client", db):
        with pytest.raises(HTTPException) as raised:
            await info_jwt_key_mapping(
                id="nonexistent-id", user_api_key_dict=_make_admin_auth()
            )
        assert raised.value.status_code == 404


@pytest.mark.asyncio
async def test_create_success_returns_response_without_token():
    """Successful create should return JWTKeyMappingResponse without hashed token.

    The mocked DB row carries a ``token`` attribute; the test confirms that the
    endpoint's return value neither declares a ``token`` field nor includes a
    ``token`` key or the row's token value in its serialized form.
    """
    from litellm.proxy._types import CreateJWTKeyMappingRequest

    created_row = _mock_mapping()
    mock_prisma = _mock_prisma()
    mock_prisma.db.litellm_jwtkeymapping.create.return_value = created_row
    mock_cache = AsyncMock()

    data = CreateJWTKeyMappingRequest(
        jwt_claim_name="email", jwt_claim_value="user@example.com", key="sk-test-key",
    )

    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), patch(
        "litellm.proxy.proxy_server.user_api_key_cache", mock_cache
    ):
        result = await create_jwt_key_mapping(data=data, user_api_key_dict=_make_admin_auth())
        assert isinstance(result, JWTKeyMappingResponse)
        # Read model_fields from the class: accessing it on an instance is
        # deprecated since Pydantic 2.11.
        assert "token" not in JWTKeyMappingResponse.model_fields
        # Also check the actual payload, not just the declared schema.
        serialized = result.model_dump()
        assert "token" not in serialized
        assert created_row.token not in serialized.values()
        assert result.jwt_claim_name == "email"
        assert result.jwt_claim_name == "email"
