#!/usr/bin/env python3
"""
Azure Text Moderation Native Guardrail Integration for LiteLLM
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, Union, cast

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import CallTypesLiteral

from .base import AzureGuardrailBase

if TYPE_CHECKING:
    from litellm.types.llms.openai import AllMessageValues
    from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
        AzureTextModerationGuardrailResponse,
    )
    from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
    from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse


class AzureContentSafetyTextModerationGuardrail(AzureGuardrailBase, CustomGuardrail):
    """
    LiteLLM Built-in Guardrail for Azure Content Safety (Text Moderation).

    This guardrail scans prompts and responses using the Azure Text Moderation API to detect
    malicious content and policy violations based on severity thresholds.

    Configuration:
        guardrail_name: Name of the guardrail instance
        api_key: Azure Text Moderation API key
        api_base: Azure Text Moderation API endpoint
        severity_threshold: Block any category whose severity is >= this value
        severity_threshold_by_category: Per-category thresholds keyed by
            category name (e.g. "Hate")
        default_on: Whether to enable by default

    When neither threshold is configured, ``default_severity_threshold`` applies.
    """

    # Fallback threshold, used only when neither a global nor a per-category
    # threshold has been configured.
    default_severity_threshold: int = 2

    def __init__(
        self,
        guardrail_name: str,
        api_key: str,
        api_base: str,
        severity_threshold: Optional[int] = None,
        severity_threshold_by_category: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize Azure Text Moderation guardrail handler.

        Args:
            guardrail_name: Name of this guardrail instance.
            api_key: Azure Content Safety API key.
            api_base: Azure Content Safety endpoint.
            severity_threshold: Global threshold, coerced with ``int()``.
                Falsy values (``None``, ``0``) are stored as ``None`` (unset).
            severity_threshold_by_category: Optional mapping of category name
                to threshold.
            **kwargs: Forwarded to the base classes. ``categories``,
                ``blocklistNames``, ``haltOnBlocklistHit`` and ``outputType``
                are also read from here to build the analyze request body.
        """
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationRequestBodyOptionalParams,
        )

        # AzureGuardrailBase.__init__ stores api_key, api_base, api_version,
        # async_handler and forwards the rest to CustomGuardrail.
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            guardrail_name=guardrail_name,
            **kwargs,
        )

        # Optional request-body fields sent with every text:analyze call.
        self.optional_params_request_body: (
            AzureTextModerationRequestBodyOptionalParams
        ) = {
            "categories": kwargs.get("categories")
            or [
                "Hate",
                "Sexual",
                "SelfHarm",
                "Violence",
            ],
            "blocklistNames": cast(
                Optional[List[str]], kwargs.get("blocklistNames") or None
            ),
            "haltOnBlocklistHit": kwargs.get("haltOnBlocklistHit") or False,
            "outputType": kwargs.get("outputType") or "FourSeverityLevels",
        }

        # int() accepts string values from config; falsy values mean "unset".
        self.severity_threshold = (
            int(severity_threshold) if severity_threshold else None
        )
        self.severity_threshold_by_category = severity_threshold_by_category

        verbose_proxy_logger.info(
            f"Initialized Azure Text Moderation Guardrail: {guardrail_name}"
        )

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """Return the pydantic config model describing this guardrail's settings."""
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureContentSafetyTextModerationConfigModel,
        )

        return AzureContentSafetyTextModerationConfigModel

    async def async_make_request(
        self, text: str
    ) -> "AzureTextModerationGuardrailResponse":
        """
        Make a request to the Azure Text Moderation API.

        Long texts are automatically split at word boundaries into chunks
        that respect the Azure Content Safety 10 000-character limit.  Each
        chunk is analysed independently; a severity-threshold violation in
        *any* chunk raises an HTTPException immediately.

        Returns:
            The response for the last chunk (earlier chunks were already
            checked and passed).

        Raises:
            HTTPException: if any chunk crosses a severity threshold.
            ValueError: if splitting produced no chunks to analyse.
        """
        from .base import AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH
        from litellm.types.proxy.guardrails.guardrail_hooks.azure.azure_text_moderation import (
            AzureTextModerationGuardrailRequestBody,
            AzureTextModerationGuardrailResponse,
        )

        chunks = self.split_text_by_words(
            text, AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH
        )

        last_response: Optional[AzureTextModerationGuardrailResponse] = None

        for chunk in chunks:
            request_body = AzureTextModerationGuardrailRequestBody(
                text=chunk,
                **self.optional_params_request_body,  # type: ignore[misc]
            )
            response_json = await self._post_to_content_safety(
                "text:analyze", cast(dict, request_body)
            )

            chunk_response = cast(AzureTextModerationGuardrailResponse, response_json)

            # For multi-chunk texts the callers only see the final response,
            # so we must check every intermediate chunk here to avoid silently
            # swallowing a violation that appears in an earlier chunk.
            try:
                self.check_severity_threshold(response=chunk_response)
            except HTTPException:
                verbose_proxy_logger.warning(
                    "Azure Text Moderation: Violation detected in chunk of length %d",
                    len(chunk),
                )
                raise

            last_response = chunk_response

        if last_response is None:
            # Expected to be unreachable (split_text_by_words is assumed to
            # return at least one chunk). Raise explicitly rather than using
            # `assert`, which is stripped under `python -O`.
            raise ValueError(
                "Azure Text Moderation: no text chunks were produced for analysis"
            )
        return last_response

    @staticmethod
    def _raise_severity_violation(
        category_name: Any, threshold: Any, severity: Any
    ) -> None:
        """Raise the HTTP 400 error used for every threshold violation."""
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Azure Content Safety Guardrail: {} crossed severity {}, Got severity: {}".format(
                    category_name,
                    threshold,
                    severity,
                )
            },
        )

    def check_severity_threshold(
        self, response: "AzureTextModerationGuardrailResponse"
    ) -> Literal[True]:
        """
        Raise if any analysed category meets or exceeds its threshold.

        Checks run in this order:
        - per-category thresholds, for the categories they list;
        - the global ``severity_threshold``, applied to every category;
        - ``default_severity_threshold``, only when neither of the above is set.

        "Set" means truthy for both settings. Using the same test for the
        fallback prevents an empty per-category mapping (or a zero global
        threshold) from silently disabling moderation.

        Raises:
            HTTPException: status 400 for the first category crossing a threshold.

        Returns:
            True when no threshold is crossed.
        """
        categories_analysis = response["categoriesAnalysis"]
        category_thresholds = self.severity_threshold_by_category
        global_threshold = self.severity_threshold

        if category_thresholds:
            for category in categories_analysis:
                threshold = category_thresholds.get(category["category"])
                if threshold is not None and category["severity"] >= threshold:
                    self._raise_severity_violation(
                        category["category"], threshold, category["severity"]
                    )

        if global_threshold:
            for category in categories_analysis:
                if category["severity"] >= global_threshold:
                    self._raise_severity_violation(
                        category["category"], global_threshold, category["severity"]
                    )

        if not category_thresholds and not global_threshold:
            for category in categories_analysis:
                if category["severity"] >= self.default_severity_threshold:
                    self._raise_severity_violation(
                        category["category"],
                        self.default_severity_threshold,
                        category["severity"],
                    )
        return True

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: "UserAPIKeyAuth",
        cache: Any,
        data: Dict[str, Any],
        call_type: CallTypesLiteral,
    ) -> Optional[Dict[str, Any]]:
        """
        Pre-call hook to scan user prompts before sending to LLM.

        Returns ``data`` unchanged when the request has no ``messages``;
        otherwise returns ``None``.

        Raises HTTPException if content should be blocked.
        """
        verbose_proxy_logger.info(
            "Azure Text Moderation: Running pre-call prompt scan, on call_type: %s",
            call_type,
        )
        new_messages: Optional[List["AllMessageValues"]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Azure Text Moderation: not running guardrail. No messages in data"
            )
            return data
        user_prompt = self.get_user_prompt(new_messages)

        if user_prompt:
            # Prompt text may contain sensitive user data: keep it out of
            # INFO-level logs.
            verbose_proxy_logger.debug(
                "Azure Text Moderation: User prompt: %s", user_prompt
            )
            await self.async_make_request(
                text=user_prompt,
            )
        else:
            verbose_proxy_logger.warning("Azure Text Moderation: No text found")
        return None

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: "UserAPIKeyAuth",
        response: Union[Any, "ModelResponse", "EmbeddingResponse", "ImageResponse"],
    ) -> Any:
        """
        Scan the first choice's message content of a chat ``ModelResponse``.

        Other response types, and responses whose first choice has no text
        content (e.g. tool-call-only replies), are passed through unscanned,
        matching how the pre-call and streaming hooks skip empty text.

        Returns:
            ``response`` unchanged.

        Raises:
            HTTPException: if the content crosses a severity threshold.
        """
        from litellm.types.utils import Choices, ModelResponse

        if (
            isinstance(response, ModelResponse)
            and response.choices
            and isinstance(response.choices[0], Choices)
        ):
            content = response.choices[0].message.content
            if content:
                await self.async_make_request(
                    text=content,
                )
            else:
                verbose_proxy_logger.debug(
                    "Azure Text Moderation: No text found in response"
                )
        return response

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ) -> Any:
        """
        Scan streamed response text.

        Returns ``response`` unchanged when it passes (or is empty). On a
        threshold violation, returns a ``data: {"error": ...}`` event string
        instead of raising.
        """
        try:
            if response:
                await self.async_make_request(
                    text=response,
                )
            return response
        except HTTPException as e:
            import json

            error_returned = json.dumps({"error": e.detail})
            return f"data: {error_returned}\n\n"
