import hashlib
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import litellm
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.management_helpers.object_permission_utils import \
    handle_update_object_permission_common
from litellm.proxy.utils import PrismaClient
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest


class AgentRegistry:
    """
    In-memory registry of agents, plus helpers that persist agents in the DB.

    Registered agents live in ``self.agent_list`` as ``AgentResponse`` objects.
    They are populated from the proxy config and/or DB rows via
    ``load_agents_from_config`` / ``load_agents_from_db_and_config``.

    The async ``*_db`` helpers read/write ``litellm_agentstable`` through the
    supplied ``PrismaClient``. They do NOT modify ``self.agent_list``; callers
    must reload the registry themselves.
    """

    # Per-agent rate-limit columns copied as-is between requests and DB rows.
    _RATE_LIMIT_FIELDS = (
        "tpm_limit",
        "rpm_limit",
        "session_tpm_limit",
        "session_rpm_limit",
    )

    def __init__(self):
        self.agent_list: List[AgentResponse] = []

    def reset_agent_list(self):
        """Remove every registered agent."""
        self.agent_list = []

    def register_agent(self, agent_config: AgentResponse):
        """Append ``agent_config`` to the registry (no de-duplication)."""
        self.agent_list.append(agent_config)

    def deregister_agent(self, agent_name: str):
        """Remove every registered agent whose ``agent_name`` equals ``agent_name``."""
        self.agent_list = [
            agent for agent in self.agent_list if agent.agent_name != agent_name
        ]

    def get_agent_list(self, agent_names: Optional[List[str]] = None):
        """
        Return registered agents.

        Args:
            agent_names: when given, only agents whose ``agent_name`` is in
                this list are returned (as a new list). When ``None``, the
                registry's own list object is returned.
        """
        if agent_names is None:
            return self.agent_list
        # Build the lookup set once so filtering is O(n) instead of O(n*m).
        wanted_names = set(agent_names)
        return [agent for agent in self.agent_list if agent.agent_name in wanted_names]

    def get_public_agent_list(self) -> List[AgentResponse]:
        """
        Return registered agents whose ``agent_id`` appears in
        ``litellm.public_agent_groups``.

        Returns an empty list when ``litellm.public_agent_groups`` is ``None``.
        """
        public_groups = litellm.public_agent_groups
        if public_groups is None:
            return []
        return [agent for agent in self.agent_list if agent.agent_id in public_groups]

    def _create_agent_id(self, agent_config: AgentConfig) -> str:
        """
        Return a deterministic id for a config-defined agent.

        The id is the SHA-256 hex digest of the config serialized as JSON with
        sorted keys, so the same config always yields the same id.
        """
        return hashlib.sha256(
            json.dumps(agent_config, sort_keys=True).encode()
        ).hexdigest()

    def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
        """
        Register agents defined in the proxy config (without resetting first).

        Items missing a truthy ``agent_name`` or ``agent_card_params`` are
        skipped. Each registered agent gets an id from ``_create_agent_id``.

        Raises:
            ValueError: if an item is not a dict. Items processed before the
                bad one stay registered.
        """
        if agent_config is None:
            return None

        for agent_config_item in agent_config:
            if not isinstance(agent_config_item, dict):
                raise ValueError("agent_config must be a list of dictionaries")

            if not agent_config_item.get("agent_name") or not agent_config_item.get(
                "agent_card_params"
            ):
                continue

            # create a stable hash id for config item
            config_hash = self._create_agent_id(agent_config_item)
            self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item))  # type: ignore

    def load_agents_from_db_and_config(
        self,
        agent_config: Optional[List[AgentConfig]] = None,
        db_agents: Optional[List[Dict[str, Any]]] = None,
    ):
        """
        Reset the registry, then register config agents followed by DB agents.

        Config agents get an id from ``_create_agent_id``. DB agents are built
        directly from their row dicts. Unlike ``load_agents_from_config``,
        config items are NOT filtered for missing fields here.
        NOTE(review): confirm this asymmetry is intended.

        Returns:
            The registry's agent list after loading.

        Raises:
            ValueError: if a config item or DB agent is not a dict.
        """
        self.reset_agent_list()

        for agent_config_item in agent_config or []:
            if not isinstance(agent_config_item, dict):
                raise ValueError("agent_config must be a list of dictionaries")
            self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item))  # type: ignore

        for db_agent in db_agents or []:
            if not isinstance(db_agent, dict):
                raise ValueError("db_agents must be a list of dictionaries")
            self.register_agent(agent_config=AgentResponse(**db_agent))  # type: ignore

        return self.agent_list

    ###########################################################
    ########### DB management helpers for agents ###########
    ############################################################
    @staticmethod
    def _model_dump_if_model(value: Any) -> Any:
        """Return ``value.model_dump()`` for pydantic-style models, else ``value``."""
        if hasattr(value, "model_dump"):
            return value.model_dump()
        return value

    @staticmethod
    def _serialize_params(params: Any) -> str:
        """
        Serialize a params object for a JSON DB column.

        Pydantic-style models are ``model_dump()``-ed. Other truthy values are
        converted with ``dict(...)``. Falsy values (``None``, ``{}``) become
        ``{}``.
        """
        if hasattr(params, "model_dump"):
            return safe_dumps(params.model_dump())
        return safe_dumps(dict(params) if params else {})

    @staticmethod
    def _object_permission_to_dict(object_permission: Any) -> Any:
        """Convert an object-permission model to a dict (``model_dump`` with ``.dict()`` fallback)."""
        try:
            return object_permission.model_dump()
        except Exception:
            return object_permission.dict()

    @classmethod
    def _db_agent_to_response(cls, db_agent: Any) -> AgentResponse:
        """Build an ``AgentResponse`` from a DB row loaded with ``object_permission`` included."""
        agent_dict = db_agent.model_dump()
        if db_agent.object_permission is not None:
            agent_dict["object_permission"] = cls._object_permission_to_dict(
                db_agent.object_permission
            )
        return AgentResponse(**agent_dict)  # type: ignore

    async def add_agent_to_db(
        self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
    ) -> AgentResponse:
        """
        Add an agent to the database.

        Args:
            agent: the agent definition to persist.
            prisma_client: client used for all DB writes.
            created_by: recorded as both ``created_by`` and ``updated_by``.

        Returns:
            The created agent, including its ``object_permission`` if any.

        Raises:
            Exception: on any failure, with the message prefixed by
                "Error adding agent to DB:" and the cause chained.
        """
        try:
            # Serialize everything up front so a bad value fails before any
            # DB write (previously an object_permission row could be orphaned).
            litellm_params = self._serialize_params(agent.get("litellm_params", {}))
            agent_card_params = self._serialize_params(
                agent.get("agent_card_params", {})
            )
            static_headers_obj = agent.get("static_headers")
            static_headers_val: Optional[str] = (
                safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
            )
            extra_headers_val: Optional[List[str]] = agent.get("extra_headers")

            # Handle object_permission (MCP tool access for agent)
            object_permission_id: Optional[str] = None
            if agent.get("object_permission") is not None:
                object_permission_id = await handle_update_object_permission_common(
                    dict(agent), None, prisma_client
                )

            # One timestamp so created_at == updated_at on a fresh row.
            now = datetime.now(timezone.utc)
            create_data: Dict[str, Any] = {
                "agent_name": agent.get("agent_name"),
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "created_by": created_by,
                "updated_by": created_by,
                "created_at": now,
                "updated_at": now,
            }
            if static_headers_val is not None:
                create_data["static_headers"] = static_headers_val
            if extra_headers_val is not None:
                create_data["extra_headers"] = extra_headers_val
            if object_permission_id is not None:
                create_data["object_permission_id"] = object_permission_id

            for rate_field in self._RATE_LIMIT_FIELDS:
                rate_value = agent.get(rate_field)
                if rate_value is not None:
                    create_data[rate_field] = rate_value

            created_agent = await prisma_client.db.litellm_agentstable.create(
                data=create_data,
                include={"object_permission": True},
            )
            return self._db_agent_to_response(created_agent)
        except Exception as e:
            raise Exception(f"Error adding agent to DB: {str(e)}") from e

    async def delete_agent_from_db(
        self, agent_id: str, prisma_client: PrismaClient
    ) -> Dict[str, Any]:
        """
        Delete an agent from the database.

        Returns:
            The deleted row as a dict.

        Raises:
            Exception: prefixed with "Error deleting agent from DB:" if the
                agent does not exist or the delete fails.
        """
        try:
            deleted_agent = await prisma_client.db.litellm_agentstable.delete(
                where={"agent_id": agent_id}
            )
            # Prisma's delete may return None instead of raising when no row
            # matched; report that clearly instead of failing inside dict().
            if deleted_agent is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            return dict(deleted_agent)
        except Exception as e:
            raise Exception(f"Error deleting agent from DB: {str(e)}") from e

    async def patch_agent_in_db(
        self,
        agent_id: str,
        agent: PatchAgentRequest,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Patch an agent in the database.

        The existing row is loaded and the keys present in ``agent`` are
        overlaid on it. ``agent_name``, ``litellm_params`` and
        ``agent_card_params`` are written from that merged view, so unpatched
        values are re-written unchanged. Rate-limit fields, ``static_headers``
        and ``extra_headers`` are written only when their key is present in
        ``agent``; an explicit ``None`` resets the headers to an empty value.
        ``object_permission`` is updated only when provided and not ``None``.

        Args:
            agent_id: The ID of the agent to patch
            agent: The new agent values to patch
            prisma_client: The Prisma client to use
            updated_by: The user ID of the user who is patching the agent

        Returns:
            The patched agent

        Raises:
            Exception: prefixed with "Error patching agent in DB:" if the agent
                does not exist or any DB call fails.
        """
        try:
            existing_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if existing_row is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            existing_agent: Dict[str, Any] = dict(existing_row)

            # Patch values take precedence over stored values.
            merged_agent: Dict[str, Any] = {**existing_agent, **agent}
            update_data: Dict[str, Any] = {}
            if merged_agent.get("agent_name"):
                update_data["agent_name"] = merged_agent.get("agent_name")
            for json_field in ("litellm_params", "agent_card_params"):
                field_value = merged_agent.get(json_field)
                if field_value:
                    # model_dump pydantic values, consistent with add/update.
                    update_data[json_field] = safe_dumps(
                        self._model_dump_if_model(field_value)
                    )

            for rate_field in self._RATE_LIMIT_FIELDS:
                if rate_field in agent:
                    update_data[rate_field] = agent.get(rate_field)
            if "static_headers" in agent:
                headers_value = agent.get("static_headers")
                update_data["static_headers"] = safe_dumps(
                    dict(headers_value) if headers_value is not None else {}
                )
            if "extra_headers" in agent:
                extra_headers_value = agent.get("extra_headers")
                update_data["extra_headers"] = (
                    extra_headers_value if extra_headers_value is not None else []
                )
            if agent.get("object_permission") is not None:
                object_permission_id = await handle_update_object_permission_common(
                    dict(merged_agent),
                    existing_agent.get("object_permission_id"),
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id

            patched_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data={
                    **update_data,
                    "updated_by": updated_by,
                    "updated_at": datetime.now(timezone.utc),
                },
                include={"object_permission": True},
            )
            # The row may have been deleted between the read and the update.
            if patched_agent is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            return self._db_agent_to_response(patched_agent)
        except Exception as e:
            raise Exception(f"Error patching agent in DB: {str(e)}") from e

    async def update_agent_in_db(
        self,
        agent_id: str,
        agent: AgentConfig,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Update (replace) an agent in the database.

        ``agent_name``, ``litellm_params``, ``agent_card_params``,
        ``static_headers`` and ``extra_headers`` are always written. Missing
        headers are stored as empty values. Rate-limit fields are written only
        when not ``None``. ``object_permission`` is updated only when provided.

        Raises:
            Exception: prefixed with "Error updating agent in DB:" if the agent
                does not exist or any DB call fails.
        """
        try:
            static_headers_obj = agent.get("static_headers")
            update_data: Dict[str, Any] = {
                "agent_name": agent.get("agent_name"),
                "litellm_params": self._serialize_params(
                    agent.get("litellm_params", {})
                ),
                "agent_card_params": self._serialize_params(
                    agent.get("agent_card_params", {})
                ),
                "static_headers": safe_dumps(
                    dict(static_headers_obj) if static_headers_obj is not None else {}
                ),
                "extra_headers": agent.get("extra_headers") or [],
                "updated_by": updated_by,
                "updated_at": datetime.now(timezone.utc),
            }

            for rate_field in self._RATE_LIMIT_FIELDS:
                rate_value = agent.get(rate_field)
                if rate_value is not None:
                    update_data[rate_field] = rate_value

            if agent.get("object_permission") is not None:
                existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
                    where={"agent_id": agent_id}
                )
                existing_object_permission_id = (
                    existing_agent.object_permission_id
                    if existing_agent is not None
                    else None
                )
                object_permission_id = await handle_update_object_permission_common(
                    dict(agent),
                    existing_object_permission_id,
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id

            updated_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data=update_data,
                include={"object_permission": True},
            )
            if updated_agent is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            return self._db_agent_to_response(updated_agent)
        except Exception as e:
            raise Exception(f"Error updating agent in DB: {str(e)}") from e

    @staticmethod
    async def get_all_agents_from_db(
        prisma_client: PrismaClient,
    ) -> List[Dict[str, Any]]:
        """
        Get all agents from the database, newest ``created_at`` first.

        Each row is returned as a dict. A loaded ``object_permission`` is
        converted to a plain dict.

        Raises:
            Exception: prefixed with "Error getting agents from DB:" on failure.
        """
        try:
            agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
                order={"created_at": "desc"},
                include={"object_permission": True},
            )

            agents: List[Dict[str, Any]] = []
            for agent in agents_from_db:
                agent_dict = dict(agent)
                # object_permission is eagerly loaded via include above
                if agent.object_permission is not None:
                    agent_dict["object_permission"] = (
                        AgentRegistry._object_permission_to_dict(agent.object_permission)
                    )
                agents.append(agent_dict)
            return agents
        except Exception as e:
            raise Exception(f"Error getting agents from DB: {str(e)}") from e

    def get_agent_by_id(
        self,
        agent_id: str,
    ) -> Optional[AgentResponse]:
        """
        Return the first registered agent with ``agent_id``, or ``None``.

        Searches the in-memory registry only; the DB is not queried.
        """
        try:
            return next(
                (agent for agent in self.agent_list if agent.agent_id == agent_id),
                None,
            )
        except Exception as e:
            raise Exception(f"Error getting agent from DB: {str(e)}") from e

    def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
        """
        Return the first registered agent with ``agent_name``, or ``None``.

        Searches the in-memory registry only; the DB is not queried.
        """
        try:
            return next(
                (agent for agent in self.agent_list if agent.agent_name == agent_name),
                None,
            )
        except Exception as e:
            raise Exception(f"Error getting agent from DB: {str(e)}") from e


# Module-level singleton: every importer of this module shares this one registry.
global_agent_registry: AgentRegistry = AgentRegistry()
