"""
A2A Protocol endpoints for LiteLLM Proxy.

Allows clients to invoke agents through LiteLLM using the A2A protocol.
The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM.
"""

import json
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.agent_endpoints.utils import merge_agent_headers
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import all_litellm_params

# Router on which the A2A endpoints defined in this module are registered.
router = APIRouter()


def _jsonrpc_error(
    request_id: Optional[str],
    code: int,
    message: str,
    status_code: int = 400,
) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope wrapped in a ``JSONResponse``.

    Args:
        request_id: Value echoed back as the envelope's ``id`` (may be ``None``).
        code: Value placed in ``error.code``.
        message: Value placed in ``error.message``.
        status_code: HTTP status of the response; defaults to 400.

    Returns:
        A ``JSONResponse`` whose body has the keys ``jsonrpc``, ``id`` and ``error``.
    """
    error_detail = {"code": code, "message": message}
    payload = {"jsonrpc": "2.0", "id": request_id, "error": error_detail}
    return JSONResponse(content=payload, status_code=status_code)


def _get_agent(agent_id: str):
    """Resolve an agent from the global agent registry.

    ``agent_id`` is first tried as a registry ID; only when that lookup yields
    ``None`` is the same value tried as an agent name. The result of the name
    lookup is returned as-is, so ``None`` means neither lookup matched.
    """
    from litellm.proxy.agent_endpoints.agent_registry import \
        global_agent_registry as registry

    found = registry.get_agent_by_id(agent_id=agent_id)
    if found is not None:
        return found
    return registry.get_agent_by_name(agent_name=agent_id)


def _enforce_inbound_trace_id(agent: Any, request: Request) -> None:
    """Reject the request when the agent demands a trace ID and none is supplied.

    The check is only active when the agent's ``litellm_params`` has a truthy
    ``require_trace_id_on_calls_to_agent`` entry. In that case the request
    headers are handed to ``get_chain_id_from_headers``; a falsy result is
    treated as "missing".

    Raises:
        HTTPException: 400 when the trace ID is required but not found.
    """
    params = agent.litellm_params or {}
    trace_id_required = params.get("require_trace_id_on_calls_to_agent")
    if trace_id_required:
        from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers

        if not get_chain_id_from_headers(dict(request.headers)):
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Agent '{agent.agent_id}' requires x-litellm-trace-id header "
                    "on all inbound requests."
                ),
            )


async def _handle_stream_message(
    api_base: Optional[str],
    request_id: str,
    params: dict,
    litellm_params: Optional[dict] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    proxy_server_request: Optional[dict] = None,
    *,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    request_data: Optional[dict] = None,
    proxy_logging_obj: Optional[Any] = None,
) -> StreamingResponse:
    """Handle the A2A ``message/stream`` method and return an NDJSON stream.

    Each item produced by ``asend_message_streaming`` is serialized as one JSON
    document followed by a newline. Failures are reported in-band as a JSON-RPC
    error line with code -32603 rather than by changing the HTTP status.

    When ``user_api_key_dict``, ``request_data`` and ``proxy_logging_obj`` are
    all provided, the stream is routed through
    ``ProxyBaseLLMRequestProcessing.async_streaming_data_generator`` with NDJSON
    serializers, and ``proxy_logging_obj.post_call_failure_hook`` is awaited if
    streaming raises.

    Args:
        api_base: Backend agent URL passed through to the SDK call.
        request_id: JSON-RPC request ID, echoed in the request and in error lines.
        params: Keyword arguments used to build ``MessageSendParams``.
        litellm_params: Passed through to ``asend_message_streaming``.
        agent_id: Passed through to ``asend_message_streaming``.
        metadata: Passed through to ``asend_message_streaming``.
        proxy_server_request: Passed through to ``asend_message_streaming``.
        agent_extra_headers: Passed through to ``asend_message_streaming``.
        user_api_key_dict: Auth context; required to enable the proxy-hook path.
        request_data: Request data dict; required to enable the proxy-hook path.
        proxy_logging_obj: Proxy logging object; required to enable the
            proxy-hook path.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``.
    """
    from litellm.a2a_protocol import asend_message_streaming
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE

    def _error_line(message: Any) -> str:
        """Serialize a JSON-RPC -32603 error for this request as one NDJSON line."""
        return (
            json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": message},
                }
            )
            + "\n"
        )

    def _ndjson_chunk(chunk: Any) -> str:
        """Serialize one stream item as an NDJSON line (pydantic-style models are dumped first)."""
        if hasattr(chunk, "model_dump"):
            obj = chunk.model_dump(mode="json", exclude_none=True)
        else:
            obj = chunk
        return json.dumps(obj) + "\n"

    if not A2A_SDK_AVAILABLE:

        async def _error_stream():
            yield _error_line("Server error: 'a2a' package not installed")

        return StreamingResponse(_error_stream(), media_type="application/x-ndjson")

    from a2a.types import MessageSendParams, SendStreamingMessageRequest

    use_proxy_hooks = (
        user_api_key_dict is not None
        and request_data is not None
        and proxy_logging_obj is not None
    )

    async def stream_response():
        try:
            a2a_request = SendStreamingMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            a2a_stream = asend_message_streaming(
                request=a2a_request,
                api_base=api_base,
                litellm_params=litellm_params,
                agent_id=agent_id,
                metadata=metadata,
                proxy_server_request=proxy_server_request,
                agent_extra_headers=agent_extra_headers,
            )

            # The explicit ``is not None`` checks repeat ``use_proxy_hooks`` so
            # that type checkers can narrow the Optional arguments here.
            if (
                use_proxy_hooks
                and user_api_key_dict is not None
                and request_data is not None
                and proxy_logging_obj is not None
            ):
                from litellm.proxy.common_request_processing import \
                    ProxyBaseLLMRequestProcessing

                def _ndjson_error(proxy_exc: Any) -> str:
                    return _error_line(
                        getattr(
                            proxy_exc,
                            "message",
                            f"Streaming error: {proxy_exc!s}",
                        )
                    )

                async for (
                    line
                ) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                    response=a2a_stream,
                    user_api_key_dict=user_api_key_dict,
                    request_data=request_data,
                    proxy_logging_obj=proxy_logging_obj,
                    serialize_chunk=_ndjson_chunk,
                    serialize_error=_ndjson_error,
                ):
                    yield line
            else:
                async for chunk in a2a_stream:
                    yield _ndjson_chunk(chunk)
        except Exception as e:
            verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
            error: Exception = e
            if (
                use_proxy_hooks
                and proxy_logging_obj is not None
                and user_api_key_dict is not None
                and request_data is not None
            ):
                transformed_exception = await proxy_logging_obj.post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=e,
                    request_data=request_data,
                )
                if transformed_exception is not None:
                    error = transformed_exception
            if isinstance(error, HTTPException):
                # Raise the selected exception explicitly: a bare ``raise``
                # would re-raise the original exception and silently drop the
                # one returned by the failure hook.
                raise error
            yield _error_line(f"Streaming error: {str(error)}")

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")


@router.get(
    "/a2a/{agent_id}/.well-known/agent-card.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/a2a/{agent_id}/.well-known/agent.json",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_agent_card(
    agent_id: str,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Serve an agent's card (A2A discovery endpoint).

    Registered under both well-known paths:
    - /.well-known/agent-card.json
    - /.well-known/agent.json

    The returned card is a copy of the agent's stored card parameters whose
    ``url`` is replaced with ``<proxy base URL>/a2a/<agent_id>``, so clients
    that follow the card send their A2A calls to this proxy.

    Raises:
        HTTPException: 404 when the agent is unknown, 403 when the caller is
            not permitted to use it, 500 for any other failure.
    """
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import \
        AgentRequestHandler

    try:
        agent = _get_agent(agent_id)
        if agent is None:
            raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")

        # Permission decision is delegated to AgentRequestHandler.
        permitted = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not permitted:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        # Work on a copy so the stored card parameters are left untouched.
        proxy_base = str(request.base_url).rstrip("/")
        proxy_url = f"{proxy_base}/a2a/{agent_id}"
        card = dict(agent.agent_card_params)
        card["url"] = proxy_url

        verbose_proxy_logger.debug(
            f"Returning agent card for '{agent_id}' with proxy URL: {proxy_url}"
        )
        return JSONResponse(content=card)

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting agent card: {e}")
        raise HTTPException(status_code=500, detail=str(e))


def _collect_dynamic_agent_headers(agent: Any, request: Request) -> Dict[str, str]:
    """Pick the client request headers that get forwarded to the backend agent.

    Two sources are combined, in this order (a later entry with the same key
    overwrites an earlier one):

    1. Header names listed in ``agent.extra_headers``: forwarded under the
       configured name when the client sent that header (case-insensitive match).
    2. Headers named ``x-a2a-{agent_id_or_name}-{header_name}``: forwarded as
       ``{header_name}``. Both ``agent.agent_id`` and ``agent.agent_name`` are
       accepted as the alias, compared in lower case.
    """
    normalized = {k.lower(): v for k, v in dict(request.headers).items()}
    forwarded: Dict[str, str] = {}

    if agent.extra_headers:
        for header_name in agent.extra_headers:
            val = normalized.get(header_name.lower())
            if val is not None:
                forwarded[header_name] = val

    for alias in (agent.agent_id.lower(), agent.agent_name.lower()):
        prefix = f"x-a2a-{alias}-"
        for key, val in normalized.items():
            if key.startswith(prefix):
                header_name = key[len(prefix) :]
                # Skip a bare prefix with nothing after it.
                if header_name:
                    forwarded[header_name] = val

    return forwarded


@router.post(
    "/a2a/{agent_id}",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/v1/a2a/{agent_id}/message/send",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
)
async def invoke_agent_a2a(  # noqa: PLR0915
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Invoke an agent using the A2A protocol (JSON-RPC 2.0).

    Supported methods:
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response

    Request validation failures are answered with JSON-RPC error envelopes:
    - -32700 when the body is not valid JSON
    - -32600 when the body is not a JSON object or ``jsonrpc`` is not "2.0"
    - -32602 when ``params`` is present but not a JSON object
    - -32601 when ``method`` is not one of the supported methods
    - -32000 when the agent is unknown (404) or has no URL configured (500)
    - -32603 for any other unexpected failure (500)

    Raises:
        HTTPException: 403 when the caller may not use the agent; any
            ``HTTPException`` raised by the called helpers is propagated as-is.
    """
    from litellm.a2a_protocol import asend_message
    from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
    from litellm.proxy.agent_endpoints.auth.agent_permission_handler import \
        AgentRequestHandler
    from litellm.proxy.proxy_server import (general_settings, proxy_config,
                                            proxy_logging_obj, version)

    body: Dict[str, Any] = {}
    try:
        try:
            parsed_body = await request.json()
        except ValueError as parse_err:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; a malformed body is a client error, not a server one.
            verbose_proxy_logger.debug(
                f"A2A request for agent '{agent_id}' has an invalid JSON body: {parse_err}"
            )
            return _jsonrpc_error(None, -32700, f"Parse error: {parse_err}")

        # Only a JSON object can be a request here; without this check a list
        # or scalar body would break every ``body.get`` below, including the
        # one in the generic exception handler.
        if not isinstance(parsed_body, dict):
            return _jsonrpc_error(
                None, -32600, "Invalid Request: body must be a JSON object"
            )
        body = parsed_body

        verbose_proxy_logger.debug(f"A2A request for agent '{agent_id}': {body}")

        # Validate JSON-RPC format
        if body.get("jsonrpc") != "2.0":
            return _jsonrpc_error(
                body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'"
            )

        request_id = body.get("id")
        method = body.get("method")
        params = body.get("params", {})
        if params is None:
            # An explicit null is treated the same as an omitted "params".
            params = {}
        if not isinstance(params, dict):
            return _jsonrpc_error(
                request_id, -32602, "Invalid params: params must be a JSON object"
            )

        # Move any litellm params (eg. 'guardrails') from params up into body.
        litellm_keys = [key for key in params if key in all_litellm_params]
        for key in litellm_keys:
            body[key] = params.pop(key)

        if not A2A_SDK_AVAILABLE:
            return _jsonrpc_error(
                request_id,
                -32603,
                "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.",
                500,
            )

        # Find the agent
        agent = _get_agent(agent_id)
        if agent is None:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' not found", 404
            )

        is_allowed = await AgentRequestHandler.is_agent_allowed(
            agent_id=agent.agent_id,
            user_api_key_auth=user_api_key_dict,
        )
        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Agent '{agent_id}' is not allowed for your key/team. Contact proxy admin for access.",
            )

        _enforce_inbound_trace_id(agent, request)

        # Get backend URL and agent name
        agent_url = agent.agent_card_params.get("url")
        agent_name = agent.agent_card_params.get("name", agent_id)

        # Get litellm_params (may include custom_llm_provider for completion bridge)
        litellm_params = agent.litellm_params or {}
        custom_llm_provider = litellm_params.get("custom_llm_provider")

        # URL is required unless using completion bridge with a provider that derives endpoint from model
        # (e.g., bedrock/agentcore derives endpoint from ARN in model string)
        if not agent_url and not custom_llm_provider:
            return _jsonrpc_error(
                request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500
            )

        verbose_proxy_logger.info(
            f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}"
        )

        # Set up data dict for litellm processing. A missing or null
        # "metadata" entry is replaced by a fresh dict before it is written to.
        if body.get("metadata") is None:
            body["metadata"] = {}
        body["metadata"]["agent_id"] = agent.agent_id

        body.update(
            {
                "model": f"a2a_agent/{agent_name}",
                "custom_llm_provider": "a2a_agent",
            }
        )

        # Add litellm data (user_api_key, user_id, team_id, etc.)
        from litellm.proxy.common_request_processing import \
            ProxyBaseLLMRequestProcessing

        processor = ProxyBaseLLMRequestProcessing(data=body)
        data, logging_obj = await processor.common_processing_pre_call_logic(
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            proxy_config=proxy_config,
            route_type="asend_message",
            version=version,
        )

        # Build merged headers for the backend agent
        static_headers: Dict[str, str] = dict(agent.static_headers or {})
        dynamic_headers = _collect_dynamic_agent_headers(agent, request)

        agent_extra_headers = merge_agent_headers(
            dynamic_headers=dynamic_headers or None,
            static_headers=static_headers or None,
        )

        # Route through SDK functions
        if method == "message/send":
            from a2a.types import MessageSendParams, SendMessageRequest

            a2a_request = SendMessageRequest(
                id=request_id,
                params=MessageSendParams(**params),
            )
            response = await asend_message(
                request=a2a_request,
                api_base=agent_url,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                litellm_logging_obj=logging_obj,
                agent_extra_headers=agent_extra_headers,
            )

            response = await proxy_logging_obj.post_call_success_hook(
                user_api_key_dict=user_api_key_dict,
                data=data,
                response=response,
            )
            return JSONResponse(
                content=(
                    response.model_dump(mode="json", exclude_none=True)  # type: ignore
                    if hasattr(response, "model_dump")
                    else response
                )
            )

        elif method == "message/stream":
            return await _handle_stream_message(
                api_base=agent_url,
                request_id=request_id,
                params=params,
                litellm_params=litellm_params,
                agent_id=agent.agent_id,
                metadata=data.get("metadata", {}),
                proxy_server_request=data.get("proxy_server_request"),
                agent_extra_headers=agent_extra_headers,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error invoking agent: {e}")
        # ``body`` is always a dict here: it is only rebound after the
        # isinstance check above.
        return _jsonrpc_error(body.get("id"), -32603, f"Internal error: {str(e)}", 500)
