import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)

if TYPE_CHECKING:
    from litellm.types.llms.openai import AllMessageValues

# Azure Content Safety APIs have a 10,000 character limit per request.
# Presumably passed as ``max_length`` when chunking long text for analysis
# (not referenced within this module's visible code — confirm at call sites).
AZURE_CONTENT_SAFETY_MAX_TEXT_LENGTH = 10000


class AzureGuardrailBase:
    """
    Base class for Azure guardrails.

    Provides shared initialisation (API credentials, HTTP client) and
    utilities (text splitting, authenticated POST) used by all Azure
    Content Safety guardrails.
    """

    def __init__(
        self,
        api_key: str,
        api_base: str,
        **kwargs: Any,
    ):
        """
        Store Azure Content Safety credentials and create the async HTTP client.

        Args:
            api_key: Azure Content Safety subscription key, sent as the
                ``Ocp-Apim-Subscription-Key`` header on every request.
            api_base: Base URL of the Content Safety resource, e.g.
                ``https://<resource>.cognitiveservices.azure.com``.
            **kwargs: Forwarded to the next class in the MRO (typically
                CustomGuardrail).  ``api_version`` is read from here but
                deliberately NOT popped, so cooperating base classes still
                receive it.
        """
        # Forward remaining kwargs to the next class in the MRO
        # (typically CustomGuardrail).
        super().__init__(**kwargs)

        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key
        self.api_base = api_base
        # Fall back to a known-good API version when the caller passes
        # nothing (or a falsy value such as "" / None).
        self.api_version: str = kwargs.get("api_version") or "2024-09-01"

    async def _post_to_content_safety(
        self, endpoint_path: str, request_body: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST to an Azure Content Safety endpoint with standard auth headers.

        Args:
            endpoint_path: The API action, e.g. ``"text:shieldPrompt"`` or
                ``"text:analyze"``.
            request_body: JSON-serialisable request payload.

        Returns:
            Parsed JSON response dict.

        NOTE(review): the HTTP status code is never checked here, so non-2xx
        error bodies are returned to the caller as plain dicts (unless the
        underlying handler raises) — confirm callers handle that shape.
        """
        url = f"{self.api_base}/contentsafety/{endpoint_path}?api-version={self.api_version}"
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/json",
        }

        verbose_proxy_logger.debug(
            "Azure Content Safety request [%s]: %s", endpoint_path, request_body
        )
        response = await self.async_handler.post(
            url=url,
            headers=headers,
            json=request_body,
        )
        response_json: Dict[str, Any] = response.json()
        verbose_proxy_logger.debug(
            "Azure Content Safety response [%s]: %s", endpoint_path, response_json
        )
        return response_json

    @staticmethod
    def split_text_by_words(text: str, max_length: int) -> List[str]:
        """
        Split text into chunks at word boundaries without breaking words.

        Always returns at least one chunk.  Short text (≤ max_length) is
        returned as a single-element list so callers can use a uniform
        loop without branching on length.  A single "word" longer than
        ``max_length`` is force-split into ``max_length``-sized pieces.

        Args:
            text: The text to split
            max_length: Maximum character length of each chunk (must be >= 1)

        Returns:
            List of text chunks, each not exceeding max_length

        Raises:
            ValueError: If ``max_length`` is less than 1.
        """
        if max_length < 1:
            # Guard against an infinite loop in the force-split below:
            # with max_length <= 0, token[:max_length] is "" and token
            # never shrinks.
            raise ValueError("max_length must be >= 1")

        if len(text) <= max_length:
            return [text]

        # Tokenize into alternating non-whitespace and whitespace runs so
        # that original newlines, tabs, and multiple spaces are preserved
        # within each chunk.
        tokens = re.findall(r"\S+|\s+", text)

        chunks: List[str] = []
        current_chunk = ""

        for token in tokens:
            # Would appending this token exceed the limit?
            if len(current_chunk) + len(token) <= max_length:
                current_chunk += token
            else:
                # Flush whatever we have accumulated so far
                if current_chunk:
                    chunks.append(current_chunk)
                    current_chunk = ""

                # Force-split any single token longer than max_length
                while len(token) > max_length:
                    chunks.append(token[:max_length])
                    token = token[max_length:]

                current_chunk = token

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def get_user_prompt(self, messages: List["AllMessageValues"]) -> Optional[str]:
        """
        Get the last consecutive block of messages from the user.

        Example:
        messages = [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm good, thank you!"},
            {"role": "user", "content": "What is the weather in Tokyo?"},
        ]
        get_user_prompt(messages) -> "What is the weather in Tokyo?"

        Returns:
            Newline-joined text of the trailing user messages, or ``None``
            when there are no messages, no trailing user block, or the
            joined text is empty after stripping.
        """
        from litellm.litellm_core_utils.prompt_templates.common_utils import (
            convert_content_list_to_str,
        )

        if not messages:
            return None

        # Walk backwards and collect the trailing run of user messages,
        # stopping at the first non-user message.
        trailing_user_messages: List["AllMessageValues"] = []
        for message in reversed(messages):
            if message.get("role") != "user":
                break
            trailing_user_messages.append(message)

        if not trailing_user_messages:
            return None

        # The collected messages are in reverse order; restore chronological
        # order and join each message's text content with newlines.
        user_prompt = "\n".join(
            convert_content_list_to_str(message)
            for message in reversed(trailing_user_messages)
        )

        result = user_prompt.strip()
        return result if result else None
