# litellm/proxy/guardrails/guardrail_registry.py

import importlib
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Type, cast

import litellm
from litellm import Router
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.guardrails.guardrail_hooks.grayswan import GraySwanGuardrail
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
    initialize_guardrail as initialize_grayswan,
)
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.proxy.utils import PrismaClient
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
    Guardrail,
    GuardrailEventHooks,
    LakeraCategoryThresholds,
    LitellmParams,
    SupportedGuardrailIntegrations,
)

from .guardrail_initializers import (
    initialize_bedrock,
    initialize_hide_secrets,
    initialize_lakera,
    initialize_lakera_v2,
    initialize_presidio,
    initialize_tool_permission,
)

# Built-in guardrail integrations: maps a guardrail type string (the
# SupportedGuardrailIntegrations value) to the function that initializes it.
# Initializers discovered in the guardrail_hooks package are merged into this
# dict at import time by the module-level update calls below the discovery
# functions; a discovered entry with the same key replaces the built-in one.
guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock,
    SupportedGuardrailIntegrations.LAKERA.value: initialize_lakera,
    SupportedGuardrailIntegrations.LAKERA_V2.value: initialize_lakera_v2,
    SupportedGuardrailIntegrations.PRESIDIO.value: initialize_presidio,
    SupportedGuardrailIntegrations.HIDE_SECRETS.value: initialize_hide_secrets,
    SupportedGuardrailIntegrations.TOOL_PERMISSION.value: initialize_tool_permission,
    SupportedGuardrailIntegrations.GRAYSWAN.value: initialize_grayswan,
}

# Maps a guardrail type string to its CustomGuardrail subclass. Only GraySwan is
# registered statically here; classes exposed by hook modules through their own
# `guardrail_class_registry` dict are merged in at import time.
guardrail_class_registry: Dict[str, Type[CustomGuardrail]] = {
    SupportedGuardrailIntegrations.GRAYSWAN.value: GraySwanGuardrail
}


def get_guardrail_initializer_from_hooks():
    """
    Discover guardrail initializers from the ``guardrail_hooks`` package.

    Every immediate subdirectory of ``guardrail_hooks`` that contains an
    ``__init__.py`` (and whose name does not start with ``__``) is imported as
    ``litellm.proxy.guardrails.guardrail_hooks.<name>``. For each module:

    - if it defines ``guardrail_initializer_registry`` and that value is a dict,
      all of its entries are merged into the result;
    - otherwise, if it defines ``initialize_guardrail``, that function is
      registered under the directory name.

    Subdirectories are visited in sorted order. When two modules register the
    same key, the later one wins, so sorting makes that outcome deterministic
    instead of dependent on the filesystem's listing order.

    Import or processing errors in a single module are logged and that module
    is skipped. Any other failure is logged and whatever was discovered up to
    that point is returned.

    Returns:
        Dict[str, Callable]: A dictionary mapping guardrail types to their initializer functions
    """
    discovered_initializers = {}

    try:
        # Get the path to the guardrail_hooks directory
        current_dir = os.path.dirname(__file__)
        hooks_dir = os.path.join(current_dir, "guardrail_hooks")

        if not os.path.exists(hooks_dir):
            verbose_proxy_logger.debug("guardrail_hooks directory not found")
            return discovered_initializers

        # Scan each subdirectory in a stable (sorted) order so that key
        # collisions between hook modules resolve the same way on every host.
        for item in sorted(os.listdir(hooks_dir)):
            item_path = os.path.join(hooks_dir, item)

            # Skip files and __pycache__ directories
            if not os.path.isdir(item_path) or item.startswith("__"):
                continue

            # Only Python packages (directories with __init__.py) are importable
            init_file = os.path.join(item_path, "__init__.py")
            if not os.path.exists(init_file):
                continue

            module_path = f"litellm.proxy.guardrails.guardrail_hooks.{item}"
            try:
                verbose_proxy_logger.debug(f"Discovering guardrails in: {module_path}")

                module = importlib.import_module(module_path)

                # Preferred: module exposes a full type -> initializer mapping
                if hasattr(module, "guardrail_initializer_registry"):
                    registry = getattr(module, "guardrail_initializer_registry")
                    if isinstance(registry, dict):
                        discovered_initializers.update(registry)
                        verbose_proxy_logger.debug(
                            f"Found guardrail_initializer_registry in {module_path}: {list(registry.keys())}"
                        )

                # Fallback: a single initialize_guardrail function, keyed by the
                # directory name
                elif hasattr(module, "initialize_guardrail"):
                    initialize_fn = getattr(module, "initialize_guardrail")
                    discovered_initializers[item] = initialize_fn
                    verbose_proxy_logger.debug(
                        f"Found initialize_guardrail function in {module_path}"
                    )

            except ImportError as e:
                verbose_proxy_logger.error(f"Could not import {module_path}: {e}")
                continue
            except Exception as e:
                verbose_proxy_logger.error(f"Error processing {module_path}: {e}")
                continue

        verbose_proxy_logger.debug(
            f"Discovered {len(discovered_initializers)} guardrail initializers: {list(discovered_initializers.keys())}"
        )

    except Exception as e:
        verbose_proxy_logger.error(f"Error discovering guardrail initializers: {e}")

    return discovered_initializers


def get_guardrail_class_from_hooks():
    """
    Discover guardrail classes from the ``guardrail_hooks`` package.

    Every immediate subdirectory of ``guardrail_hooks`` that contains an
    ``__init__.py`` (and whose name does not start with ``__``) is imported as
    ``litellm.proxy.guardrails.guardrail_hooks.<name>``. If the module defines
    a ``guardrail_class_registry`` dict, its entries are merged into the result.

    Subdirectories are visited in sorted order so that, when two modules
    register the same key, the winner is deterministic.

    Import errors are logged at debug level and the module is skipped. Other
    per-module errors are logged with a traceback and the module is skipped.
    Any other failure is logged and whatever was discovered so far is returned.

    Returns:
        Dict[str, Type[CustomGuardrail]]: A dictionary mapping guardrail types to their guardrail classes
    """
    discovered_classes = {}

    try:
        # Get the path to the guardrail_hooks directory
        current_dir = os.path.dirname(__file__)
        hooks_dir = os.path.join(current_dir, "guardrail_hooks")

        if not os.path.exists(hooks_dir):
            verbose_proxy_logger.debug("guardrail_hooks directory not found")
            return discovered_classes

        # Scan each subdirectory in a stable (sorted) order
        for item in sorted(os.listdir(hooks_dir)):
            item_path = os.path.join(hooks_dir, item)

            # Skip files and __pycache__ directories
            if not os.path.isdir(item_path) or item.startswith("__"):
                continue

            # Only Python packages (directories with __init__.py) are importable
            init_file = os.path.join(item_path, "__init__.py")
            if not os.path.exists(init_file):
                continue

            module_path = f"litellm.proxy.guardrails.guardrail_hooks.{item}"

            try:
                verbose_proxy_logger.debug(f"Discovering guardrails in: {module_path}")

                module = importlib.import_module(module_path)

                # Merge the module's guardrail_class_registry dict, if it has one
                if hasattr(module, "guardrail_class_registry"):
                    registry = getattr(module, "guardrail_class_registry")
                    if isinstance(registry, dict):
                        discovered_classes.update(registry)

            except ImportError as e:
                verbose_proxy_logger.debug(f"Could not import {module_path}: {e}")
                continue
            except Exception as e:
                verbose_proxy_logger.exception(f"Error processing {module_path}: {e}")
                continue

        verbose_proxy_logger.debug(
            f"Discovered {len(discovered_classes)} guardrail classes: {list(discovered_classes.keys())}"
        )

    except Exception as e:
        verbose_proxy_logger.error(f"Error discovering guardrail classes: {e}")

    return discovered_classes


# Import-time side effect: import every guardrail_hooks subpackage and merge
# the classes they expose into the static class registry. Discovered entries
# replace static entries that use the same key.
guardrail_class_registry.update(get_guardrail_class_from_hooks())


# Merge with dynamically discovered guardrail initializers
# (discovered entries replace built-in initializers that use the same key)
_discovered_initializers = get_guardrail_initializer_from_hooks()

guardrail_initializer_registry.update(_discovered_initializers)


class GuardrailRegistry:
    """
    Registry for guardrails

    Handles adding, removing, and getting guardrails in DB + in memory.

    Every DB helper wraps a failure in a plain ``Exception`` whose message
    names the failed operation. The original error is chained as ``__cause__``
    so its traceback is not lost.
    """

    def __init__(self):
        pass

    ###########################################################
    ########### In memory management helpers for guardrails ###########
    ############################################################
    def get_initialized_guardrail_callback(
        self, guardrail_name: str
    ) -> Optional[CustomGuardrail]:
        """
        Return the active CustomGuardrail callback with the given name.

        The callback is looked up through ``litellm.logging_callback_manager``.
        The first callback whose ``guardrail_name`` matches is returned;
        returns None when nothing matches.
        """
        active_guardrails = (
            litellm.logging_callback_manager.get_custom_loggers_for_type(
                callback_type=CustomGuardrail
            )
        )
        for active_guardrail in active_guardrails:
            if (
                isinstance(active_guardrail, CustomGuardrail)
                and active_guardrail.guardrail_name == guardrail_name
            ):
                return active_guardrail
        return None

    ###########################################################
    ########### DB management helpers for guardrails ###########
    ############################################################
    @staticmethod
    def _serialize_guardrail_for_db(guardrail: Guardrail) -> Dict[str, Any]:
        """
        Build the columns shared by the create and update DB writes.

        ``litellm_params`` may be a pydantic model (serialized via
        ``model_dump()``) or a mapping. It is JSON-encoded with ``safe_dumps``,
        and so is ``guardrail_info``.

        Returns:
            dict with ``guardrail_name``, ``litellm_params`` (str) and
            ``guardrail_info`` (str), in that order.
        """
        litellm_params_obj: Any = guardrail.get("litellm_params", {})
        if hasattr(litellm_params_obj, "model_dump"):
            litellm_params_dict = litellm_params_obj.model_dump()
        else:
            litellm_params_dict = dict(litellm_params_obj) if litellm_params_obj else {}
        return {
            "guardrail_name": guardrail.get("guardrail_name"),
            "litellm_params": safe_dumps(litellm_params_dict),
            "guardrail_info": safe_dumps(guardrail.get("guardrail_info", {})),
        }

    async def add_guardrail_to_db(
        self, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Add a guardrail to the database

        Returns:
            A dict copy of ``guardrail`` with the DB-generated ``guardrail_id``.

        Raises:
            Exception: "Error adding guardrail to DB: ..." on any failure.
        """
        try:
            data = self._serialize_guardrail_for_db(guardrail)
            # Use one timestamp so a freshly created row has
            # created_at == updated_at exactly.
            now = datetime.now(timezone.utc)

            created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
                data={
                    **data,
                    "created_at": now,
                    "updated_at": now,
                }
            )

            # Add guardrail_id to the returned guardrail object
            guardrail_dict = dict(guardrail)
            guardrail_dict["guardrail_id"] = created_guardrail.guardrail_id

            return guardrail_dict
        except Exception as e:
            raise Exception(f"Error adding guardrail to DB: {str(e)}") from e

    async def delete_guardrail_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ):
        """
        Delete a guardrail from the database

        Returns:
            A dict with a confirmation ``message``.

        Raises:
            Exception: "Error deleting guardrail from DB: ..." on any failure.
        """
        try:
            await prisma_client.db.litellm_guardrailstable.delete(
                where={"guardrail_id": guardrail_id}
            )

            return {"message": f"Guardrail {guardrail_id} deleted successfully"}
        except Exception as e:
            raise Exception(f"Error deleting guardrail from DB: {str(e)}") from e

    async def update_guardrail_in_db(
        self, guardrail_id: str, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Update a guardrail in the database

        Overwrites name, litellm_params and guardrail_info, and bumps updated_at.

        Returns:
            The updated DB row converted to a dict.

        Raises:
            Exception: "Error updating guardrail in DB: ..." on any failure.
        """
        try:
            data = self._serialize_guardrail_for_db(guardrail)

            updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
                where={"guardrail_id": guardrail_id},
                data={
                    **data,
                    "updated_at": datetime.now(timezone.utc),
                },
            )

            return dict(updated_guardrail)
        except Exception as e:
            raise Exception(f"Error updating guardrail in DB: {str(e)}") from e

    @staticmethod
    async def get_all_guardrails_from_db(
        prisma_client: PrismaClient,
    ) -> List[Guardrail]:
        """
        Get all active guardrails from the database.
        Only rows with status == "active" are returned (pending_review and rejected are excluded).
        Results are ordered newest first by created_at.

        Raises:
            Exception: "Error getting guardrails from DB: ..." on any failure.
        """
        try:
            guardrails_from_db = (
                await prisma_client.db.litellm_guardrailstable.find_many(
                    where={"status": "active"},
                    order={"created_at": "desc"},
                )
            )

            return [
                Guardrail(**(dict(guardrail)))  # type: ignore
                for guardrail in guardrails_from_db
            ]
        except Exception as e:
            raise Exception(f"Error getting guardrails from DB: {str(e)}") from e

    async def get_guardrail_by_id_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Get a guardrail by its ID from the database

        Returns None when no row matches.

        Raises:
            Exception: "Error getting guardrail from DB: ..." on any failure.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_id": guardrail_id}
            )

            if not guardrail:
                return None

            return Guardrail(**(dict(guardrail)))  # type: ignore
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e

    async def get_guardrail_by_name_from_db(
        self, guardrail_name: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Get a guardrail by its name from the database

        Returns None when no row matches.

        Raises:
            Exception: "Error getting guardrail from DB: ..." on any failure.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_name": guardrail_name}
            )

            if not guardrail:
                return None

            return Guardrail(**(dict(guardrail)))  # type: ignore
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e


class InMemoryGuardrailHandler:
    """
    Class that handles initializing guardrails and adding them to the CallbackManager
    """

    def __init__(self):
        # Guardrail id -> parsed Guardrail object
        self.IN_MEMORY_GUARDRAILS: Dict[str, Guardrail] = {}

        # Guardrail id -> CustomGuardrail callback created for it (None when the
        # initializer returned no callback)
        self.guardrail_id_to_custom_guardrail: Dict[str, Optional[CustomGuardrail]] = {}

    def initialize_guardrail(
        self,
        guardrail: Guardrail,
        config_file_path: Optional[str] = None,
        llm_router: Optional["Router"] = None,
    ) -> Optional[Guardrail]:
        """
        Initialize a guardrail from a dictionary and add it to the litellm callback manager

        - Assigns a fresh uuid4 ``guardrail_id`` when the guardrail has none
          (the input ``guardrail`` is mutated in place).
        - If the id is already in IN_MEMORY_GUARDRAILS, returns the cached
          Guardrail without initializing again.
        - Resolves ``os.environ/...`` references in api_key / api_base via
          get_secret.
        - Dispatches to the registered initializer for ``litellm_params.guardrail``,
          or to initialize_custom_guardrail when the type is a dotted path.

        Args:
            guardrail: definition containing ``guardrail_name`` and
                ``litellm_params`` (a dict or a LitellmParams instance).
            config_file_path: required only for custom (dotted-path) guardrails.
            llm_router: forwarded to initializers whose signature accepts it.

        Returns a Guardrail object if the guardrail is initialized successfully

        Raises:
            ValueError: if the guardrail type is missing or unsupported.
        """
        guardrail_id = guardrail.get("guardrail_id") or str(uuid.uuid4())
        guardrail["guardrail_id"] = guardrail_id
        if guardrail_id in self.IN_MEMORY_GUARDRAILS:
            verbose_proxy_logger.debug(
                "guardrail_id already exists in IN_MEMORY_GUARDRAILS"
            )
            return self.IN_MEMORY_GUARDRAILS[guardrail_id]

        custom_guardrail_callback: Optional[CustomGuardrail] = None
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        if isinstance(litellm_params_data, dict):
            litellm_params = LitellmParams(**litellm_params_data)
        else:
            litellm_params = litellm_params_data

        if (
            "category_thresholds" in litellm_params_data
            and litellm_params_data["category_thresholds"]
        ):
            lakera_category_thresholds = LakeraCategoryThresholds(
                **litellm_params_data["category_thresholds"]
            )
            litellm_params.category_thresholds = lakera_category_thresholds

        # Resolve env-var references so initializers receive the real values
        if litellm_params.api_key and litellm_params.api_key.startswith("os.environ/"):
            litellm_params.api_key = str(get_secret(litellm_params.api_key))

        if litellm_params.api_base and litellm_params.api_base.startswith(
            "os.environ/"
        ):
            litellm_params.api_base = str(get_secret(litellm_params.api_base))

        guardrail_type = litellm_params.guardrail

        if guardrail_type is None:
            raise ValueError("guardrail_type is required")

        initializer = guardrail_initializer_registry.get(guardrail_type)

        if initializer:
            # Only pass llm_router to initializers that declare that parameter
            import inspect

            sig = inspect.signature(initializer)
            if "llm_router" in sig.parameters:
                custom_guardrail_callback = initializer(
                    litellm_params, guardrail, llm_router  # type: ignore
                )
            else:
                custom_guardrail_callback = initializer(litellm_params, guardrail)
        elif isinstance(guardrail_type, str) and "." in guardrail_type:
            custom_guardrail_callback = self.initialize_custom_guardrail(
                guardrail=cast(dict, guardrail),
                guardrail_type=guardrail_type,
                litellm_params=litellm_params,
                config_file_path=config_file_path,
            )
        else:
            raise ValueError(f"Unsupported guardrail: {guardrail_type}")

        parsed_guardrail = Guardrail(
            guardrail_id=guardrail.get("guardrail_id"),
            guardrail_name=guardrail["guardrail_name"],
            litellm_params=litellm_params,
        )

        # store references to the guardrail in memory
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = parsed_guardrail
        self.guardrail_id_to_custom_guardrail[guardrail_id] = custom_guardrail_callback

        return parsed_guardrail

    def initialize_custom_guardrail(
        self,
        guardrail: Dict,
        guardrail_type: str,
        litellm_params: LitellmParams,
        config_file_path: Optional[str] = None,
    ) -> Optional[CustomGuardrail]:
        """
        Initialize a Custom Guardrail from a python file or module path

        This initializes it by adding it to the litellm callback manager.
        Every non-None litellm_param except guardrail/mode/default_on is
        forwarded to the guardrail class constructor as a keyword argument.

        Raises:
            Exception: if config_file_path is not provided.
            ValueError: if litellm_params.mode is not set.
        """
        if not config_file_path:
            raise Exception(
                "GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
            )

        verbose_proxy_logger.debug(
            "Initializing custom guardrail: %s",
            guardrail_type,
        )

        _guardrail_class = get_instance_fn(guardrail_type, config_file_path=config_file_path)

        mode = litellm_params.mode
        if mode is None:
            raise ValueError(
                f"mode is required for guardrail {guardrail_type} please set mode to one of the following: {', '.join(GuardrailEventHooks)}"
            )

        default_on = litellm_params.default_on

        # Extract additional params from litellm_params to pass to custom guardrail
        # This matches the behavior of other guardrail initializers (e.g., initialize_lakera)
        # and aligns with the documented behavior for custom guardrails
        if hasattr(litellm_params, "model_dump"):
            extra_params = litellm_params.model_dump(exclude_none=True)
        else:
            extra_params = dict(litellm_params) if litellm_params else {}

        # Remove params that are handled explicitly or are internal
        for key in ["guardrail", "mode", "default_on"]:
            extra_params.pop(key, None)

        _guardrail_callback = _guardrail_class(
            guardrail_name=guardrail["guardrail_name"],
            event_hook=mode,
            default_on=default_on,
            **extra_params,
        )
        litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback)  # type: ignore

        return _guardrail_callback

    def update_in_memory_guardrail(
        self, guardrail_id: str, guardrail: Guardrail
    ) -> None:
        """
        Update a guardrail in memory

        - updates the guardrail in memory
        - updates the guardrail params in litellm.callback_manager
        """
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = guardrail

        custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.get(
            guardrail_id
        )
        if custom_guardrail_callback:
            updated_litellm_params = cast(
                LitellmParams, guardrail.get("litellm_params", {})
            )
            custom_guardrail_callback.update_in_memory_litellm_params(
                litellm_params=updated_litellm_params
            )

    def delete_in_memory_guardrail(self, guardrail_id: str) -> None:
        """
        Delete a guardrail in memory and remove from litellm callbacks.

        No-op for ids that are not present.
        """
        self.IN_MEMORY_GUARDRAILS.pop(guardrail_id, None)

        custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.pop(
            guardrail_id, None
        )
        if custom_guardrail_callback:
            litellm.logging_callback_manager.remove_callback_from_list_by_object(
                callback_list=litellm.callbacks,
                obj=custom_guardrail_callback,
                require_self=False,
            )

    def list_in_memory_guardrails(self) -> List[Guardrail]:
        """
        List all guardrails in memory
        """
        return list(self.IN_MEMORY_GUARDRAILS.values())

    def get_guardrail_by_id(self, guardrail_id: str) -> Optional[Guardrail]:
        """
        Get a guardrail by its ID from memory
        """
        return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)

    @staticmethod
    def _normalize_litellm_params_for_comparison(params: Any) -> Any:
        """
        Convert litellm_params into a form that can be compared field by field.

        LitellmParams instances are converted with ``model_dump()``. In the
        resulting dict, ``os.environ/...`` references in ``api_key`` /
        ``api_base`` are resolved with get_secret, the same way
        initialize_guardrail resolves them. The in-memory copy already holds
        the resolved values, while a guardrail re-read from the DB still holds
        the reference string. Without this step the two never compare equal,
        and the guardrail is re-initialized on every sync.

        Non-dict values (including None) are returned unchanged.
        """
        params_dict = (
            params.model_dump() if isinstance(params, LitellmParams) else params
        )
        if not isinstance(params_dict, dict):
            return params_dict

        normalized = dict(params_dict)
        for key in ("api_key", "api_base"):
            value = normalized.get(key)
            if isinstance(value, str) and value.startswith("os.environ/"):
                normalized[key] = str(get_secret(value))
        return normalized

    def _has_guardrail_params_changed(
        self, guardrail_id: str, new_guardrail: Guardrail
    ) -> bool:
        """
        Check if guardrail params or name have changed compared to in-memory version.
        Returns True if params/name changed or guardrail doesn't exist in memory.

        Only litellm_params and guardrail_name are compared. Env-var references
        are resolved on both sides before comparing.
        """
        existing = self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
        if existing is None:
            return True

        if existing.get("guardrail_name") != new_guardrail.get("guardrail_name"):
            return True

        existing_dict = self._normalize_litellm_params_for_comparison(
            existing.get("litellm_params")
        )
        new_dict = self._normalize_litellm_params_for_comparison(
            new_guardrail.get("litellm_params")
        )

        # Collect the names of every field whose value differs
        changed_fields: List[str] = []
        if existing_dict is not None and new_dict is not None:
            all_keys = set(existing_dict.keys()) | set(new_dict.keys())
            changed_fields = sorted(
                key for key in all_keys if existing_dict.get(key) != new_dict.get(key)
            )
        elif existing_dict != new_dict:
            changed_fields = ["litellm_params"]

        # Log only field names: the values can contain resolved credentials
        # (e.g. api_key), which must not end up in logs.
        if changed_fields:
            verbose_proxy_logger.debug(
                f"Guardrail params changed. Changed fields: {changed_fields}"
            )

        return len(changed_fields) > 0

    def reinitialize_guardrail(
        self,
        guardrail: Guardrail,
        config_file_path: Optional[str] = None,
        llm_router: Optional["Router"] = None,
    ) -> Optional[Guardrail]:
        """
        Force re-initialization of a guardrail even if it exists in memory.
        Removes old callback from litellm.callbacks and creates fresh instance.

        ``llm_router`` is forwarded to initialize_guardrail, so initializers
        that need the router also receive it on re-initialization.
        Returns None (and logs an error) when the guardrail has no id.
        """
        guardrail_id = guardrail.get("guardrail_id")
        if not guardrail_id:
            verbose_proxy_logger.error(
                "Cannot reinitialize guardrail without guardrail_id"
            )
            return None

        # Remove from memory if exists (also removes from callbacks)
        if guardrail_id in self.IN_MEMORY_GUARDRAILS:
            self.delete_in_memory_guardrail(guardrail_id)

        # Initialize fresh (will add new callback to litellm.callbacks)
        return self.initialize_guardrail(
            guardrail=guardrail,
            config_file_path=config_file_path,
            llm_router=llm_router,
        )

    def sync_guardrail_from_db(
        self,
        guardrail: Guardrail,
        config_file_path: Optional[str] = None,
        llm_router: Optional["Router"] = None,
    ) -> Optional[Guardrail]:
        """
        Sync a guardrail from DB - initializes if new, re-initializes if changed.
        This is the method to call during DB polling.

        ``llm_router`` is forwarded when a (re-)initialization is needed.
        Returns None (and logs an error) when the guardrail has no id.
        """
        guardrail_id = guardrail.get("guardrail_id")
        if not guardrail_id:
            verbose_proxy_logger.error("Cannot sync guardrail without guardrail_id")
            return None

        if self._has_guardrail_params_changed(guardrail_id, guardrail):
            guardrail_name = guardrail.get("guardrail_name", "Unknown")
            verbose_proxy_logger.info(
                f"Guardrail '{guardrail_name}' (ID: {guardrail_id}) params changed, re-initializing..."
            )
            return self.reinitialize_guardrail(
                guardrail=guardrail,
                config_file_path=config_file_path,
                llm_router=llm_router,
            )

        return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)


########################################################
# In Memory Guardrail Handler for LiteLLM Proxy
########################################################
IN_MEMORY_GUARDRAIL_HANDLER = InMemoryGuardrailHandler()
########################################################
