from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, get_type_hints

import httpx
from openai.types.responses import ResponseReasoningItem
from pydantic import BaseModel, ValidationError

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    _safe_convert_created_field,
)
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import *
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders

from ..common_utils import OpenAIError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier for this config; always ``LlmProviders.OPENAI``."""
        return LlmProviders.OPENAI

    def get_supported_openai_params(self, model: str) -> list:
        """
        List every parameter name accepted for an OpenAI Responses API call.

        The result is the union of a fixed group of request-level names and the
        field names declared on ``ResponsesAPIRequestParams``. It is built from a
        set, so it contains no duplicates and has no guaranteed order.
        ``model`` is not consulted.
        """
        param_names = {
            "input",
            "model",
            "extra_headers",
            "extra_query",
            "extra_body",
            "timeout",
        }
        # Iterating the type-hint mapping yields its keys, i.e. the field names.
        param_names.update(get_type_hints(ResponsesAPIRequestParams))
        return list(param_names)

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(response_api_optional_params)

    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the request body for a Responses API call.

        Inputs are already in OpenAI format, so the only processing is running
        ``input`` through ``_validate_input_param`` before assembling the body
        from ``model``, the validated input and the optional params.
        ``litellm_params`` and ``headers`` are not used here.
        """
        sanitized_input = self._validate_input_param(input)
        request_params = ResponsesAPIRequestParams(
            model=model,
            input=sanitized_input,
            **response_api_optional_request_params,
        )
        return dict(request_params)

    def _validate_input_param(
        self, input: Union[str, ResponseInputParam]
    ) -> Union[str, ResponseInputParam]:
        """
        Ensure all input fields if pydantic are converted to dict

        OpenAI API Fails when we try to JSON dumps specific input pydantic fields.
        This function ensures all input fields are converted to dict.
        """
        if isinstance(input, list):
            validated_input = []
            for item in input:
                # if it's pydantic, convert to dict
                if isinstance(item, BaseModel):
                    validated_input.append(item.model_dump(exclude_none=True))
                elif isinstance(item, dict):
                    # Handle reasoning items specifically to filter out status=None
                    if item.get("type") == "reasoning":
                        verbose_logger.debug(f"Handling reasoning item: {item}")
                        # Type assertion since we know it's a dict at this point
                        dict_item = cast(Dict[str, Any], item)
                        filtered_item = self._handle_reasoning_item(dict_item)
                    else:
                        # For other dict items, just pass through
                        filtered_item = cast(Dict[str, Any], item)
                    validated_input.append(filtered_item)
                else:
                    validated_input.append(item)
            return validated_input  # type: ignore
        # Input is expected to be either str or List, no single BaseModel expected
        return input

    def _handle_reasoning_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle reasoning items specifically to filter out status=None using OpenAI's model.
        Issue: https://github.com/BerriAI/litellm/issues/13484
        OpenAI API does not accept ReasoningItem(status=None), so we need to:
        1. Check if the item is a reasoning type
        2. Create a ResponseReasoningItem object with the item data
        3. Convert it back to dict with exclude_none=True to filter None values
        """
        if item.get("type") == "reasoning":
            try:
                # Ensure required fields are present for ResponseReasoningItem
                item_data = dict(item)
                if "summary" not in item_data:
                    item_data["summary"] = (
                        item_data.get("reasoning_content", "")[:100] + "..."
                        if len(item_data.get("reasoning_content", "")) > 100
                        else item_data.get("reasoning_content", "")
                    )

                # Create ResponseReasoningItem object from the item data
                reasoning_item = ResponseReasoningItem(**item_data)

                # Convert back to dict with exclude_none=True to exclude None fields
                dict_reasoning_item = reasoning_item.model_dump(exclude_none=True)

                return dict_reasoning_item
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to create ResponseReasoningItem, falling back to manual filtering: {e}"
                )
                # Fallback: manually filter out known None fields
                filtered_item = {
                    k: v
                    for k, v in item.items()
                    if v is not None
                    or k not in {"status", "content", "encrypted_content"}
                }
                return filtered_item
        return item

    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Convert the raw HTTP response into a ``ResponsesAPIResponse``.

        The body is already in OpenAI format. This method logs the raw text via
        ``logging_obj.post_call``, parses the JSON, and passes ``created_at``
        through ``_safe_convert_created_field``. Any failure in those steps is
        raised as ``OpenAIError`` carrying the response text and status code.

        If strict construction of ``ResponsesAPIResponse`` fails, the object is
        built with ``model_construct`` instead. Raw and processed headers are
        stored in ``_hidden_params``.
        """
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            payload = raw_response.json()
            payload["created_at"] = _safe_convert_created_field(
                payload["created_at"]
            )
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        original_headers = dict(raw_response.headers)
        client_headers = process_response_headers(original_headers)

        try:
            parsed = ResponsesAPIResponse(**payload)
        except Exception:
            verbose_logger.debug(
                f"Error constructing ResponsesAPIResponse: {payload}, using model_construct"
            )
            parsed = ResponsesAPIResponse.model_construct(**payload)

        # Processed headers go in additional_headers so they reach the client.
        parsed._hidden_params["additional_headers"] = client_headers
        parsed._hidden_params["headers"] = original_headers
        return parsed

    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        litellm_params = litellm_params or GenericLiteLLMParams()
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the endpoint for OpenAI responses API
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )

        # Remove trailing slashes
        api_base = api_base.rstrip("/")

        return f"{api_base}/responses"

    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Turn one parsed streaming chunk into its typed event object.

        The event class is chosen from the chunk's ``type`` field. If the
        chunk has an ``error`` dict whose ``code`` is None, a copy of the chunk
        with ``code`` set to ``"unknown_error"`` is used, since some
        OpenAI-compatible providers send ``error.code: null``. When pydantic
        validation fails, the event is built with ``model_construct`` instead.
        ``model`` and ``logging_obj`` are unused.
        """
        verbose_logger.debug("Raw OpenAI Chunk=%s", parsed_chunk)
        chunk_type = str(parsed_chunk.get("type"))
        event_cls = OpenAIResponsesAPIConfig.get_event_model_class(
            event_type=chunk_type
        )

        # Coalesce a null error.code on copies so the caller's dict is untouched.
        try:
            error_payload = parsed_chunk.get("error")
            if isinstance(error_payload, dict) and error_payload.get("code") is None:
                parsed_chunk = {
                    **parsed_chunk,
                    "error": {**error_payload, "code": "unknown_error"},
                }
        except Exception:
            verbose_logger.debug("Failed to coalesce error.code in parsed_chunk")

        try:
            return event_cls(**parsed_chunk)
        except ValidationError:
            verbose_logger.debug(
                "Pydantic validation failed for %s with chunk %s, "
                "falling back to model_construct",
                event_cls.__name__,
                parsed_chunk,
            )
            return event_cls.model_construct(**parsed_chunk)

    @staticmethod
    def get_event_model_class(event_type: str) -> Any:
        """
        Returns the appropriate event model class based on the event type.

        Args:
            event_type (str): The type of event from the response chunk

        Returns:
            Any: The corresponding event model class, or ``GenericEvent`` when
            the event type has no dedicated model. Unknown event types do not
            raise.
        """
        event_models = {
            ResponsesAPIStreamEvents.RESPONSE_CREATED: ResponseCreatedEvent,
            ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS: ResponseInProgressEvent,
            ResponsesAPIStreamEvents.RESPONSE_COMPLETED: ResponseCompletedEvent,
            ResponsesAPIStreamEvents.RESPONSE_FAILED: ResponseFailedEvent,
            ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: ResponseIncompleteEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: OutputItemAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: OutputItemDoneEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_ADDED: ContentPartAddedEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_DONE: ContentPartDoneEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: OutputTextDeltaEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED: OutputTextAnnotationAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE: OutputTextDoneEvent,
            ResponsesAPIStreamEvents.REFUSAL_DELTA: RefusalDeltaEvent,
            ResponsesAPIStreamEvents.REFUSAL_DONE: RefusalDoneEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: FunctionCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: FunctionCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS: FileSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING: FileSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED: FileSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS: WebSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: WebSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED: WebSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS: MCPListToolsInProgressEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED: MCPListToolsCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED: MCPListToolsFailedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS: MCPCallInProgressEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA: MCPCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE: MCPCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.MCP_CALL_COMPLETED: MCPCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_FAILED: MCPCallFailedEvent,
            ResponsesAPIStreamEvents.IMAGE_GENERATION_PARTIAL_IMAGE: ImageGenerationPartialImageEvent,
            ResponsesAPIStreamEvents.ERROR: ErrorEvent,
            # Shell tool events: passthrough as GenericEvent so payload is preserved
            ResponsesAPIStreamEvents.SHELL_CALL_IN_PROGRESS: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_COMPLETED: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_OUTPUT: GenericEvent,
        }

        # Unmapped event types fall back to GenericEvent instead of raising.
        return event_models.get(
            cast(ResponsesAPIStreamEvents, event_type), GenericEvent
        )

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        if stream is not True:
            return False
        if model is not None:
            try:
                if (
                    litellm.utils.supports_native_streaming(
                        model=model,
                        custom_llm_provider=custom_llm_provider,
                    )
                    is False
                ):
                    return True
            except Exception as e:
                verbose_logger.debug(
                    f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
                )
        return False

    def supports_native_websocket(self) -> bool:
        """
        OpenAI supports native WebSocket for Responses API

        Returns:
            bool: Always ``True`` for this config.
        """
        return True

    #########################################################
    ########## DELETE RESPONSE API TRANSFORMATION ##############
    #########################################################
    def transform_delete_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the delete response API request into a URL and data

        OpenAI API expects the following request
        - DELETE /v1/responses/{response_id}
        """
        url = f"{api_base}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        """
        Parse the delete-response HTTP reply into a ``DeleteResponseResult``.

        Raises:
            OpenAIError: If the body cannot be decoded as JSON; carries the
                response text and status code.
        """
        try:
            payload = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        else:
            return DeleteResponseResult(**payload)

    #########################################################
    ########## GET RESPONSE API TRANSFORMATION ###############
    #########################################################
    def transform_get_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the get response API request into a URL and data

        OpenAI API expects the following request
        - GET /v1/responses/{response_id}
        """
        url = f"{api_base}/{response_id}"
        data: Dict = {}
        return url, data

    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Parse the get-response HTTP reply into a ``ResponsesAPIResponse``.

        Raw and processed response headers are stored in ``_hidden_params``.

        NOTE(review): unlike ``transform_response_api_response``, this path has
        no ``model_construct`` fallback and does not normalize ``created_at``,
        so a validation failure propagates. Confirm whether that is intended.

        Raises:
            OpenAIError: If the body cannot be decoded as JSON; carries the
                response text and status code.
        """
        try:
            payload = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        original_headers = dict(raw_response.headers)
        client_headers = process_response_headers(original_headers)

        result = ResponsesAPIResponse(**payload)
        result._hidden_params["additional_headers"] = client_headers
        result._hidden_params["headers"] = original_headers
        return result

    #########################################################
    ########## LIST INPUT ITEMS TRANSFORMATION #############
    #########################################################
    def transform_list_input_items_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        before: Optional[str] = None,
        include: Optional[List[str]] = None,
        limit: int = 20,
        order: Literal["asc", "desc"] = "desc",
    ) -> Tuple[str, Dict]:
        url = f"{api_base}/{response_id}/input_items"
        params: Dict[str, Any] = {}
        if after is not None:
            params["after"] = after
        if before is not None:
            params["before"] = before
        if include:
            params["include"] = ",".join(include)
        if limit is not None:
            params["limit"] = limit
        if order is not None:
            params["order"] = order
        return url, params

    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        try:
            return raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )

    #########################################################
    ########## CANCEL RESPONSE API TRANSFORMATION ##########
    #########################################################
    def transform_cancel_response_api_request(
        self,
        response_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the cancel response API request into a URL and data

        OpenAI API expects the following request
        - POST /v1/responses/{response_id}/cancel
        """
        url = f"{api_base}/{response_id}/cancel"
        data: Dict = {}
        return url, data

    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Parse the cancel-response HTTP reply into a ``ResponsesAPIResponse``.

        Raw and processed response headers are stored in ``_hidden_params``.

        NOTE(review): unlike ``transform_response_api_response``, this path has
        no ``model_construct`` fallback and does not normalize ``created_at``,
        so a validation failure propagates. Confirm whether that is intended.

        Raises:
            OpenAIError: If the body cannot be decoded as JSON; carries the
                response text and status code.
        """
        try:
            payload = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        original_headers = dict(raw_response.headers)
        client_headers = process_response_headers(original_headers)

        cancelled = ResponsesAPIResponse(**payload)
        cancelled._hidden_params["additional_headers"] = client_headers
        cancelled._hidden_params["headers"] = original_headers
        return cancelled

    #########################################################
    ########## COMPACT RESPONSE API TRANSFORMATION ##########
    #########################################################
    def transform_compact_response_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Build the URL and body for a compact-response request.

        OpenAI API expects the following request
        - POST /v1/responses/compact

        ``/compact`` is appended to the path of ``api_base`` only, so query
        params such as ``api-version`` are preserved. The body is assembled
        the same way as a regular Responses API request: ``input`` goes
        through ``_validate_input_param`` first.
        """
        base_url = httpx.URL(api_base)
        compact_url = base_url.copy_with(path=base_url.path.rstrip("/") + "/compact")
        url = str(compact_url)

        sanitized_input = self._validate_input_param(input)
        request_params = ResponsesAPIRequestParams(
            model=model,
            input=sanitized_input,
            **response_api_optional_request_params,
        )
        return url, dict(request_params)

    def transform_compact_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Parse the compact-response HTTP reply into a ``ResponsesAPIResponse``.

        This method logs the raw text via ``logging_obj.post_call``, parses the
        JSON, and passes ``created_at`` through ``_safe_convert_created_field``.
        Any failure in those steps is raised as ``OpenAIError`` carrying the
        response text and status code.

        If strict construction of ``ResponsesAPIResponse`` fails, the object is
        built with ``model_construct`` instead. Raw and processed headers are
        stored in ``_hidden_params``.
        """
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            payload = raw_response.json()
            payload["created_at"] = _safe_convert_created_field(
                payload["created_at"]
            )
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        original_headers = dict(raw_response.headers)
        client_headers = process_response_headers(original_headers)

        try:
            compacted = ResponsesAPIResponse(**payload)
        except Exception:
            verbose_logger.debug(
                f"Error constructing ResponsesAPIResponse: {payload}, using model_construct"
            )
            compacted = ResponsesAPIResponse.model_construct(**payload)

        compacted._hidden_params["additional_headers"] = client_headers
        compacted._hidden_params["headers"] = original_headers
        return compacted
