"""
Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke`

In the Huggingface TGI format. 
"""

from typing import TYPE_CHECKING, Any, List, Optional, Union

if TYPE_CHECKING:
    from litellm.types.llms.openai import AllEmbeddingInputValues

from httpx._models import Headers, Response

from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import Usage, EmbeddingResponse
from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig

from ..common_utils import SagemakerError


class SagemakerEmbeddingConfig(BaseEmbeddingConfig):
    """
    SageMaker embedding configuration factory for supporting embedding parameters
    """
    
    def __init__(self) -> None:
        pass

    @classmethod
    def get_model_config(cls, model: str) -> "BaseEmbeddingConfig":
        """
        Factory method to get the appropriate embedding config based on model type
        
        Args:
            model: The model name
            
        Returns:
            Appropriate embedding config instance
        """
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig()
        else:
            return cls()

    def get_supported_openai_params(self, model: str) -> List[str]:
        # Check if this is an embedding model
        if "voyage" in model.lower():
            return VoyageEmbeddingConfig().get_supported_openai_params(model)
        else:
            return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
    
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def transform_embedding_request(
        self,
        model: str,
        input: "AllEmbeddingInputValues",
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform embedding request for Hugging Face models on SageMaker
        """
        # HF models expect "inputs" field (plural)
        return {"inputs": input, **optional_params}

    def transform_embedding_response(
        self,
        model: str,
        raw_response: Response,
        model_response: "EmbeddingResponse",
        logging_obj: Any,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> "EmbeddingResponse":
        """
        Transform embedding response for Hugging Face models on SageMaker
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise SagemakerError(
                message=f"Failed to parse response: {str(e)}", 
                status_code=raw_response.status_code
            )

        # Handle both raw array format (TEI) and wrapped format (standard HF)
        if isinstance(response_data, list):
            # TEI and some HF models return raw embedding arrays directly
            embeddings = response_data
        elif isinstance(response_data, dict) and "embedding" in response_data:
            # Standard HF format with "embedding" key
            embeddings = response_data["embedding"]
        else:
            raise SagemakerError(
                status_code=500,
                message=f"Unexpected response format. Expected list or dict with 'embedding' key, got: {type(response_data).__name__}",
            )

        if not isinstance(embeddings, list):
            raise SagemakerError(
                status_code=422,
                message=f"HF response not in expected format - {embeddings}",
            )

        output_data = []
        for idx, embedding in enumerate(embeddings):
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embedding}
            )

        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model

        # Calculate usage from request data
        input_texts = request_data.get("inputs", [])
        input_tokens = 0
        for text in input_texts:
            input_tokens += len(text.split())  # Simple word count fallback

        model_response.usage = Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
        )

        return model_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment for SageMaker embeddings
        """
        return {"Content-Type": "application/json"}
