"""
This module is used to pass through requests to the LLM APIs.
"""

import asyncio
import contextvars
from functools import partial
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Coroutine,
    Generator,
    List,
    Optional,
    Union,
    cast,
)

import httpx
from httpx._types import CookieTypes, QueryParamTypes, RequestFiles

import litellm
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.passthrough.utils import CommonUtils
from litellm.utils import client

# Module-level shared handler, used by the passthrough routes below to
# convert raw exceptions into provider-specific errors (_handle_error).
base_llm_http_handler = BaseLLMHTTPHandler()
from .utils import BasePassthroughUtils

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig


@client
async def allm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[httpx.Response, AsyncGenerator[Any, Any]]:
    """
    Async: pass a request through to the underlying LLM provider API.

    Resolves the provider from ``model``/``custom_llm_provider``, then runs
    the sync ``llm_passthrough_route`` in the default thread-pool executor
    (with the caller's contextvars preserved) so request building does not
    block the event loop. Because ``allm_passthrough_route=True`` is set in
    ``kwargs``, the sync function returns a coroutine, which is awaited here.

    Returns:
        ``httpx.Response`` for non-streaming requests, or an async generator
        of raw byte chunks for streaming requests.

    Raises:
        httpx.HTTPStatusError: re-raised unchanged so callers (e.g. the proxy
            layer) keep the original status/details.
        Exception: other failures are converted via the provider config's
            error handler when one can be resolved; otherwise re-raised.
    """
    try:
        loop = asyncio.get_event_loop()
        # Tell llm_passthrough_route to return a coroutine instead of
        # sending the request itself (the async path).
        kwargs["allm_passthrough_route"] = True

        model, custom_llm_provider, api_key, api_base = get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
        )

        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        # Prefer an explicitly supplied provider_config; otherwise look one
        # up from the provider registry.
        provider_config = cast(
            Optional["BasePassthroughConfig"], kwargs.get("provider_config")
        ) or ProviderConfigManager.get_provider_passthrough_config(
            provider=LlmProviders(custom_llm_provider),
            model=model,
        )

        if provider_config is None:
            raise Exception(f"Provider {custom_llm_provider} not found")

        func = partial(
            llm_passthrough_route,
            method=method,
            endpoint=endpoint,
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            request_query_params=request_query_params,
            request_headers=request_headers,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            cookies=cookies,
            client=client,
            **kwargs,
        )

        # Run the (potentially blocking) request-building step in the default
        # executor, carrying over the caller's contextvars.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        # Since allm_passthrough_route=True, we always get a coroutine from _async_passthrough_request
        if asyncio.iscoroutine(init_response):
            response = await init_response

            # Only call raise_for_status if it's a Response object (not a generator)
            if isinstance(response, httpx.Response):
                response.raise_for_status()
            
            return response
        else:
            # This shouldn't happen when allm_passthrough_route=True, but handle it for type safety
            raise Exception("Expected coroutine from async passthrough route")

    except httpx.HTTPStatusError as e:
        # For HTTP errors, re-raise as-is to preserve the original error details
        # The caller (e.g., proxy layer) can handle conversion to appropriate response format
        raise e
    except Exception as e:
        # For other exceptions, use provider-specific error handling
        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        # Get the provider using the same logic as llm_passthrough_route
        # NOTE(review): if get_llm_provider itself raises here, that new
        # exception (chained to the original) propagates — confirm intended.
        _, resolved_custom_llm_provider, _, _ = get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
        )

        # Get provider config if available
        provider_config = None
        if resolved_custom_llm_provider:
            try:
                provider_config = cast(
                    Optional["BasePassthroughConfig"], kwargs.get("provider_config")
                ) or ProviderConfigManager.get_provider_passthrough_config(
                    provider=LlmProviders(resolved_custom_llm_provider),
                    model=model,
                )
            except Exception:
                # If we can't get provider config, pass None
                pass

        if provider_config is None:
            # If no provider config available, raise the original exception
            raise e

        raise base_llm_http_handler._handle_error(
            e=e,
            provider_config=provider_config,
        )


@client
def llm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    allm_passthrough_route: bool = False,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[
    httpx.Response,
    Coroutine[Any, Any, httpx.Response],
    Coroutine[Any, Any, Union[httpx.Response, AsyncGenerator[Any, Any]]],
    Generator[Any, Any, Any],
    AsyncGenerator[Any, Any],
]:
    """
    Pass through requests to the LLM APIs.

    Step 1. Build the request
    Step 2. Send the request
    Step 3. Return the response

    When ``allm_passthrough_route=True`` the request is NOT sent here;
    a coroutine (``_async_passthrough_request``) is returned for the async
    caller to await. Otherwise the request is sent synchronously and either
    an ``httpx.Response`` (non-streaming) or a byte-chunk generator
    (streaming) is returned.
    """
    from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    _is_async = allm_passthrough_route

    # Fall back to the module-level shared client matching the sync/async path.
    if client is None:
        if _is_async:
            client = litellm.module_level_aclient
        else:
            client = litellm.module_level_client

    # Injected by the @client decorator machinery via kwargs.
    litellm_logging_obj = cast("LiteLLMLoggingObj", kwargs.get("litellm_logging_obj"))

    model, custom_llm_provider, api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        api_key=api_key,
    )

    litellm_params_dict = get_litellm_params(**kwargs)

    # Add model_id to litellm_params if present in kwargs (for Bedrock Application Inference Profiles)
    if "model_id" in kwargs:
        litellm_params_dict["model_id"] = kwargs["model_id"]

    litellm_logging_obj.update_environment_variables(
        model=model,
        litellm_params=litellm_params_dict,
        optional_params={},
        endpoint=endpoint,
        custom_llm_provider=custom_llm_provider,
        request_data=data if data else json,
    )

    # Prefer an explicitly supplied provider_config; otherwise resolve from
    # the provider registry.
    provider_config = cast(
        Optional["BasePassthroughConfig"], kwargs.get("provider_config")
    ) or ProviderConfigManager.get_provider_passthrough_config(
        provider=LlmProviders(custom_llm_provider),
        model=model,
    )
    if provider_config is None:
        raise Exception(f"Provider {custom_llm_provider} not found")

    updated_url, base_target_url = provider_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        endpoint=endpoint,
        request_query_params=request_query_params,
        litellm_params=litellm_params_dict,
    )

    # [TODO: Refactor to bedrockpassthroughconfig] need to encode the id of application-inference-profile for bedrock
    if custom_llm_provider == "bedrock" and "application-inference-profile" in endpoint:
        encoded_url_str = CommonUtils.encode_bedrock_runtime_modelid_arn(
            str(updated_url)
        )
        updated_url = httpx.URL(encoded_url_str)

    # Add or update query parameters
    provider_api_key = provider_config.get_api_key(api_key)

    # Build provider auth headers (API key, env-derived credentials, ...).
    auth_headers = provider_config.validate_environment(
        headers={},
        model=model,
        messages=[],
        optional_params={},
        litellm_params=litellm_params_dict,
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    # Merge auth headers with the incoming request headers (forwarding off).
    headers = BasePassthroughUtils.forward_headers_from_request(
        request_headers=request_headers or {},
        headers=auth_headers,
        forward_headers=False,
    )

    # Signing must happen AFTER header merging — e.g. AWS SigV4 signs the
    # final header set; it may also return a re-serialized body.
    headers, signed_json_body = provider_config.sign_request(
        headers=headers,
        litellm_params=litellm_params_dict,
        request_data=data if data else json,
        api_base=str(updated_url),
        model=model,
    )

    ## SWAP MODEL IN JSON BODY [TODO: REFACTOR TO A provider_config.transform_request method]
    if json and isinstance(json, dict) and "model" in json:
        json["model"] = model

    # Exactly one body source is passed: a signed body (or raw content) wins;
    # data/json are suppressed when content is already set.
    request = client.client.build_request(
        method=method,
        url=updated_url,
        content=signed_json_body if signed_json_body is not None else content,
        data=data if (signed_json_body is None and content is None) else None,
        files=files,
        json=json if (signed_json_body is None and content is None) else None,
        params=params,
        headers=headers,
        cookies=cookies,
    )

    ## IS STREAMING REQUEST
    is_streaming_request = provider_config.is_streaming_request(
        endpoint=endpoint,
        request_data=data or json or {},
    )

    # Update logging object with streaming status
    litellm_logging_obj.stream = is_streaming_request

    ## LOGGING PRE-CALL
    request_data = data if data else json
    litellm_logging_obj.pre_call(
        input=request_data,
        api_key=provider_api_key,
        additional_args={
            "complete_input_dict": request_data,
            "api_base": str(updated_url),
            "headers": headers,
        },
    )

    try:
        if _is_async:
            # Return the coroutine to be awaited by the caller
            return _async_passthrough_request(
                client=client,
                request=request,
                is_streaming_request=is_streaming_request,
                litellm_logging_obj=litellm_logging_obj,
                provider_config=provider_config,
            )
        else:
            # Sync path - client.client.send returns Response directly
            response: httpx.Response = client.client.send(request=request, stream=is_streaming_request)  # type: ignore
            response.raise_for_status()

            if (
                hasattr(response, "iter_bytes") and is_streaming_request
            ):  # yield the chunk, so we can store it in the logging object
                return _sync_streaming(response, litellm_logging_obj, provider_config)
            else:
                # For non-streaming responses, yield the entire response
                return response
    except Exception as e:
        if provider_config is None:
            raise e
        # Convert to a provider-specific error before propagating.
        raise base_llm_http_handler._handle_error(
            e=e,
            provider_config=provider_config,
        )


async def _async_passthrough_request(
    client: Union[HTTPHandler, AsyncHTTPHandler],
    request: httpx.Request,
    is_streaming_request: bool,
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
) -> Union[httpx.Response, AsyncGenerator[Any, Any]]:
    """
    Dispatch an already-built request through the async client.

    For streaming requests, returns the async byte-chunk generator (the send
    coroutine is awaited lazily inside it). For non-streaming requests,
    awaits the send, reads the full body, checks the status, and returns the
    httpx.Response.
    """
    send_result = client.client.send(request=request, stream=is_streaming_request)

    # Guard: an async client must hand back an awaitable here.
    if not asyncio.iscoroutine(send_result):
        raise Exception("Expected coroutine from async client")

    if is_streaming_request:
        # Hand the un-awaited coroutine to the streaming generator, which
        # owns the response lifecycle from this point on.
        return _async_streaming(
            response=send_result,
            litellm_logging_obj=litellm_logging_obj,
            provider_config=provider_config,
        )

    response = await send_result
    await response.aread()
    response.raise_for_status()
    return response


def _sync_streaming(
    response: httpx.Response,
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    """
    Generator yielding raw byte chunks from a sync streaming response.

    Every chunk is both yielded to the consumer and collected; once the
    stream is fully consumed, the collected bytes are submitted to the
    logging object's flush on a background executor so logging never blocks
    the consumer.

    NOTE(review): if the consumer abandons the generator before exhaustion,
    the flush is never submitted — confirm that is acceptable.
    """
    from litellm.utils import executor

    # Fix: removed the pointless `try: ... except Exception as e: raise e`
    # wrapper — it changed nothing except obscuring the traceback origin.
    raw_bytes: List[bytes] = []
    for chunk in response.iter_bytes():  # type: ignore
        raw_bytes.append(chunk)
        yield chunk

    executor.submit(
        litellm_logging_obj.flush_passthrough_collected_chunks,
        raw_bytes=raw_bytes,
        provider_config=provider_config,
    )


async def _async_streaming(
    response: Coroutine[Any, Any, httpx.Response],
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    """
    Async generator yielding raw byte chunks from a streaming response.

    ``response`` is the *un-awaited* coroutine returned by the async
    client's send(); it is awaited only when this generator is first
    iterated. After the stream completes, the collected chunks are handed
    to the logging object on a fire-and-forget task. On error the
    underlying response is closed (best-effort) before re-raising.
    """
    iter_response = await response
    try:
        iter_response.raise_for_status()
        raw_bytes: List[bytes] = []

        async for chunk in iter_response.aiter_bytes():  # type: ignore

            raw_bytes.append(chunk)
            yield chunk

        # Fire-and-forget logging flush. NOTE(review): the task reference is
        # not retained; an un-referenced task can be garbage-collected before
        # completing — confirm this is acceptable here.
        asyncio.create_task(
            litellm_logging_obj.async_flush_passthrough_collected_chunks(
                raw_bytes=raw_bytes,
                provider_config=provider_config,
            )
        )
    except Exception:
        # Release the HTTP connection; swallow close errors so the original
        # exception propagates unchanged.
        try:
            await iter_response.aclose()
        except Exception:
            pass
        raise
