from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Optional, Union

import httpx
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.bedrock.image_generation.amazon_nova_canvas_transformation import (
    AmazonNovaCanvasConfig,
)
from litellm.llms.bedrock.image_generation.amazon_stability1_transformation import (
    AmazonStabilityConfig,
)
from litellm.llms.bedrock.image_generation.amazon_stability3_transformation import (
    AmazonStability3Config,
)
from litellm.llms.bedrock.image_generation.amazon_titan_transformation import (
    AmazonTitanImageGenerationConfig,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.utils import ImageResponse

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError

if TYPE_CHECKING:
    from botocore.awsrequest import AWSPreparedRequest
else:
    AWSPreparedRequest = Any


class BedrockImagePreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image generation
    """

    endpoint_url: str
    prepped: AWSPreparedRequest
    body: bytes
    data: dict


BedrockImageConfigClass = Union[
    type[AmazonTitanImageGenerationConfig],
    type[AmazonNovaCanvasConfig],
    type[AmazonStability3Config],
    type[AmazonStabilityConfig],
]


class BedrockImageGeneration(BaseAWSLLM):
    """
    Bedrock Image Generation handler
    """

    @classmethod
    def get_config_class(cls, model: str | None) -> BedrockImageConfigClass:
        if AmazonTitanImageGenerationConfig._is_titan_model(model):
            return AmazonTitanImageGenerationConfig
        elif AmazonNovaCanvasConfig._is_nova_model(model):
            return AmazonNovaCanvasConfig
        elif AmazonStability3Config._is_stability_3_model(model):
            return AmazonStability3Config
        else:
            return litellm.AmazonStabilityConfig

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: LitellmLogging,
        timeout: Optional[Union[float, httpx.Timeout]],
        aimg_generation: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
            prompt=prompt,
            api_key=api_key,
        )

        if aimg_generation is True:
            return self.async_image_generation(
                prepared_request=prepared_request,
                timeout=timeout,
                model=model,
                logging_obj=logging_obj,
                prompt=prompt,
                model_response=model_response,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )

        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model_response=model_response,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
        )
        return model_response

    async def async_image_generation(
        self,
        prepared_request: BedrockImagePreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]],
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        model_response: ImageResponse,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Asynchronous handler for bedrock image generation

        Awaits the response from the bedrock image generation endpoint
        """
        async_client = client or get_async_httpx_client(
            llm_provider=litellm.LlmProviders.BEDROCK,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
            model_response=model_response,
        )
        return model_response

    def _extract_headers_from_optional_params(self, optional_params: dict) -> dict:
        """
        Extract guardrail parameters from optional_params and convert them to headers.
        """
        headers = {}
        guardrail_identifier = optional_params.pop("guardrailIdentifier", None)
        guardrail_version = optional_params.pop("guardrailVersion", None)
        
        if guardrail_identifier is not None:
            headers["x-amz-bedrock-guardrail-identifier"] = guardrail_identifier
        if guardrail_version is not None:
            headers["x-amz-bedrock-guardrail-version"] = guardrail_version
            
        return headers

    def _prepare_request(
        self,
        model: str,
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        prompt: str,
        api_key: Optional[str],
    ) -> BedrockImagePreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API

        Args:
            model (str): The model to use for the image generation
            optional_params (dict): The optional parameters for the image generation
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            prompt (str): The prompt to use for the image generation
        Returns:
            BedrockImagePreparedRequest: The prepared request object

        The BedrockImagePreparedRequest contains:
            endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
            prepped (httpx.Request): The prepared request object
            body (bytes): The request body
        """
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        # Use the existing ARN-aware provider detection method
        bedrock_provider = self.get_bedrock_invoke_provider(model)
        ### SET RUNTIME ENDPOINT ###
        modelId = self.get_bedrock_model_id(
            model=model,
            provider=bedrock_provider,
            optional_params=optional_params,
        )
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
        data = self._get_request_body(
            model=model,
            prompt=prompt,
            optional_params=optional_params,
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        # Extract guardrail parameters and add them as headers
        guardrail_headers = self._extract_headers_from_optional_params(optional_params)
        headers.update(guardrail_headers)

        prepped = self.get_request_headers(
            credentials=boto3_credentials_info.credentials,
            aws_region_name=boto3_credentials_info.aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=body,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImagePreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Generation API

        Checks the model/provider and transforms the request body accordingly

        Returns:
            dict: The request body to use for the Bedrock Image Generation API
        """
        config_class = self.get_config_class(model=model)
        request_body = config_class.transform_request_body(text=prompt, optional_params=optional_params)
        return dict(request_body)

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        response: httpx.Response,
        data: dict,
    ) -> ImageResponse:
        """
        Transforms the Image Generation response from Bedrock to OpenAI format
        """

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        verbose_logger.debug("raw model_response: %s", response.text)
        response_dict = response.json()
        if response_dict is None:
            raise ValueError("Error in response object format, got None")

        config_class = self.get_config_class(model=model)

        config_class.transform_response_dict_to_openai_response(
            model_response=model_response,
            response_dict=response_dict,
        )

        return model_response
