import hashlib
import json
import os
import secrets
from datetime import datetime, timezone
from typing import Any, List, Literal, Optional, cast

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import (
    LITELLM_TRUNCATED_PAYLOAD_FIELD,
    LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
    REDACTED_BY_LITELM_STRING,
)
from litellm.constants import (
    MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB,
)
from litellm.litellm_core_utils.core_helpers import (
    get_litellm_metadata_from_kwargs,
    reconstruct_model_name,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import (
    CostBreakdown,
    StandardLoggingGuardrailInformation,
    StandardLoggingMCPToolCall,
    StandardLoggingModelInformation,
    StandardLoggingPayload,
    StandardLoggingVectorStoreRequest,
    VectorStoreSearchResponse,
)
from litellm.utils import get_end_user_id_for_cost_tracking


def _get_max_string_length_prompt_in_db() -> int:
    """
    Resolve the prompt truncation threshold at runtime so that values loaded
    later via the proxy config's `environment_variables` are honored.
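
    Example (illustrative): MAX_STRING_LENGTH_PROMPT_IN_DB=1000 in the
    environment yields 1000; an unset or non-numeric value falls back to
    DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB.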
    """
    max_length_str = os.getenv("MAX_STRING_LENGTH_PROMPT_IN_DB")
    if max_length_str is None:
        return DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB
    try:
        return int(max_length_str)
    except (TypeError, ValueError):
        return DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB


def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
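    """
    Return True if `api_key` matches the master key, either in plaintext or as
    its hash.

    Illustrative (assuming hash_token is LiteLLM's deterministic token hash):
        _is_master_key("sk-1234", "sk-1234")             -> True
        _is_master_key(hash_token("sk-1234"), "sk-1234") -> True
        _is_master_key("sk-other", "sk-1234")            -> False
    """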
    if _master_key is None:
        return False

    ## string comparison
    is_master_key = secrets.compare_digest(api_key, _master_key)
    if is_master_key:
        return True

    ## hash comparison
    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
    if is_master_key:
        return True

    return False


def _get_spend_logs_metadata(
    metadata: Optional[dict],
    applied_guardrails: Optional[List[str]] = None,
    batch_models: Optional[List[str]] = None,
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
    vector_store_request_metadata: Optional[
        List[StandardLoggingVectorStoreRequest]
    ] = None,
    guardrail_information: Optional[List[StandardLoggingGuardrailInformation]] = None,
    usage_object: Optional[dict] = None,
    model_map_information: Optional[StandardLoggingModelInformation] = None,
    cold_storage_object_key: Optional[str] = None,
    litellm_overhead_time_ms: Optional[float] = None,
    cost_breakdown: Optional[CostBreakdown] = None,
) -> SpendLogsMetadata:
    if metadata is None:
        return SpendLogsMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_project_id=None,
            user_api_key_org_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            additional_usage_values=None,
            applied_guardrails=None,
            status="success",
            error_information=None,
            proxy_server_request=None,
            batch_models=None,
            mcp_tool_call_metadata=None,
            vector_store_request_metadata=None,
            model_map_information=None,
            usage_object=None,
            guardrail_information=None,
            cold_storage_object_key=cold_storage_object_key,
            litellm_overhead_time_ms=None,
            attempted_retries=None,
            max_retries=None,
            cost_breakdown=None,
        )
    verbose_proxy_logger.debug(
        "getting payload for SpendLogs, available keys in metadata: %s",
        list(metadata.keys()),
    )

    # Filter the metadata dictionary to include only the specified keys
    clean_metadata = SpendLogsMetadata(
        **{  # type: ignore
            key: metadata.get(key) for key in SpendLogsMetadata.__annotations__.keys()
        }
    )
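    # Illustrative: only keys declared on SpendLogsMetadata survive, so e.g.
    # {"user_api_key": "<hash>", "internal_debug": {...}} keeps user_api_key
    # and drops internal_debug.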
    clean_metadata["applied_guardrails"] = applied_guardrails
    clean_metadata["batch_models"] = batch_models
    clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
    clean_metadata["vector_store_request_metadata"] = (
        _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
    )
    clean_metadata["guardrail_information"] = guardrail_information
    clean_metadata["usage_object"] = usage_object
    clean_metadata["model_map_information"] = model_map_information
    clean_metadata["cold_storage_object_key"] = cold_storage_object_key
    clean_metadata["litellm_overhead_time_ms"] = litellm_overhead_time_ms
    clean_metadata["cost_breakdown"] = cost_breakdown

    return clean_metadata


def generate_hash_from_response(response_obj: Any) -> str:
    """
    Generate a stable hash from a response object.

    Args:
        response_obj: The response object to hash (can be dict, list, etc.)

    Returns:
        A hex string representation of the MD5 hash
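
    Example (illustrative): key order does not affect the result because the
    payload is canonicalized via json.dumps(..., sort_keys=True):

        >>> a = generate_hash_from_response({"a": 1, "b": 2})
        >>> b = generate_hash_from_response({"b": 2, "a": 1})
        >>> a == b
        True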
    """
    try:
        # Create a stable JSON string of the entire response object
        # Sort keys to ensure consistent ordering
        json_str = json.dumps(response_obj, sort_keys=True)

        # Generate a hash of the response object
        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
        return unique_hash
    except Exception:
        # Return a fallback hash if serialization fails
        return hashlib.md5(str(response_obj).encode()).hexdigest()


def get_spend_logs_id(
    call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
    if call_type == "aretrieve_batch" or call_type == "acreate_file":
        # Generate a hash from the response object
        id: Optional[str] = generate_hash_from_response(response_obj)
    else:
        id = cast(Optional[str], response_obj.get("id")) or cast(
            Optional[str], kwargs.get("litellm_call_id")
        )
    return id


def _extract_usage_for_ocr_call(response_obj: Any, response_obj_dict: dict) -> dict:
    """
    Extract usage information for OCR/AOCR calls.

    OCR responses use usage_info (with pages_processed) instead of token-based usage.

    Args:
        response_obj: The raw response object (can be dict, BaseModel, or other)
        response_obj_dict: Dictionary representation of the response object

    Returns:
        A dict with prompt_tokens=0, completion_tokens=0, total_tokens=0,
        and pages_processed from usage_info.
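
    Example (illustrative; fields inside usage_info are provider-specific):
        {"usage_info": {"pages_processed": 3}} ->
        {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
         "pages_processed": 3}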
    """
    usage_info = None

    # Try to extract usage_info from dict
    if isinstance(response_obj_dict, dict) and "usage_info" in response_obj_dict:
        usage_info = response_obj_dict.get("usage_info")

    # Try to extract usage_info from object attributes if not found in dict
    if not usage_info and hasattr(response_obj, "usage_info"):
        usage_info = response_obj.usage_info
        if hasattr(usage_info, "model_dump"):
            usage_info = usage_info.model_dump()
        elif hasattr(usage_info, "__dict__"):
            usage_info = vars(usage_info)

    # For OCR, we track pages instead of tokens
    if usage_info is not None:
        # usage_info was normalized to a dict above when possible; anything
        # else falls back to zeroed counters
        if isinstance(usage_info, dict):
            result = {
                "prompt_tokens": 0,  # OCR doesn't use traditional tokens
                "completion_tokens": 0,
                "total_tokens": 0,
            }
            # Add all fields from usage_info, including pages_processed
            for key, value in usage_info.items():
                result[key] = value
            # Ensure pages_processed exists
            if "pages_processed" not in result:
                result["pages_processed"] = 0
            return result
        else:
            return {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "pages_processed": 0,
            }
    else:
        return {}


def get_logging_payload(  # noqa: PLR0915
    kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
    from litellm.proxy.proxy_server import general_settings, master_key

    if kwargs is None:
        kwargs = {}

    if response_obj is None:
        response_obj = {}
    elif not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict):
        response_obj = {"result": str(response_obj)}
    # standardize this function to be used across s3, dynamoDB, and langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)

    # Convert response_obj to dict first
    if isinstance(response_obj, dict):
        response_obj_dict = response_obj
    elif isinstance(response_obj, BaseModel):
        response_obj_dict = response_obj.model_dump()
    else:
        response_obj_dict = {}

    # Handle OCR responses which use usage_info instead of usage
    usage: dict = {}
    if call_type in ["ocr", "aocr"]:
        usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
    else:
        # Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
        _usage = response_obj_dict.get("usage", None) or {}
        if isinstance(_usage, litellm.Usage):
            usage = dict(_usage)
        elif isinstance(_usage, dict):
            usage = _usage

    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)

    api_key = metadata.get("user_api_key", "")

    standard_logging_prompt_tokens: int = 0
    standard_logging_completion_tokens: int = 0
    standard_logging_total_tokens: int = 0
    if standard_logging_payload is not None:
        standard_logging_prompt_tokens = standard_logging_payload.get(
            "prompt_tokens", 0
        )
        standard_logging_completion_tokens = standard_logging_payload.get(
            "completion_tokens", 0
        )
        standard_logging_total_tokens = standard_logging_payload.get("total_tokens", 0)
    if api_key is not None and isinstance(api_key, str):
        if api_key.startswith("sk-"):
            # hash the api_key
            api_key = hash_token(api_key)
        if (
            _is_master_key(api_key=api_key, _master_key=master_key)
            and general_settings.get("disable_adding_master_key_hash_to_db") is True
        ):
            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    if (
        standard_logging_payload is not None
    ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
        api_key = (
            api_key
            or standard_logging_payload["metadata"].get("user_api_key_hash")
            or ""
        )
        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
            "user_api_key_end_user_id"
        )
    # NOTE: don't overwrite api_key when standard_logging_payload is None;
    # it was already extracted from metadata and hashed above.
    request_tags = (
        json.dumps(metadata.get("tags", []))
        if isinstance(metadata.get("tags", []), list)
        else "[]"
    )
    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("request_tags") is not None
    ):  # use 'tags' from standard logging payload instead
        request_tags = json.dumps(standard_logging_payload["request_tags"])
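    # Re-check master key aliasing, since api_key may have been re-derived
    # from the standard logging payload above.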
    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # Extract overhead from hidden_params if available
    litellm_overhead_time_ms = None
    if standard_logging_payload is not None:
        hidden_params = standard_logging_payload.get("hidden_params", {})
        litellm_overhead_time_ms = hidden_params.get("litellm_overhead_time_ms")

    # clean up litellm metadata
    clean_metadata = _get_spend_logs_metadata(
        metadata,
        applied_guardrails=(
            standard_logging_payload["metadata"].get("applied_guardrails", None)
            if standard_logging_payload is not None
            else None
        ),
        batch_models=(
            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
            if standard_logging_payload is not None
            else None
        ),
        mcp_tool_call_metadata=(
            standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
            if standard_logging_payload is not None
            else None
        ),
        vector_store_request_metadata=(
            standard_logging_payload["metadata"].get(
                "vector_store_request_metadata", None
            )
            if standard_logging_payload is not None
            else None
        ),
        usage_object=(
            standard_logging_payload["metadata"].get("usage_object", None)
            if standard_logging_payload is not None
            else None
        ),
        model_map_information=(
            standard_logging_payload["model_map_information"]
            if standard_logging_payload is not None
            else None
        ),
        guardrail_information=(
            standard_logging_payload.get("guardrail_information", None)
            if standard_logging_payload is not None
            else (
                metadata.get("standard_logging_guardrail_information", None)
                if metadata is not None
                else None
            )
        ),
        cold_storage_object_key=(
            standard_logging_payload["metadata"].get("cold_storage_object_key", None)
            if standard_logging_payload is not None
            else None
        ),
        litellm_overhead_time_ms=litellm_overhead_time_ms,
        cost_breakdown=(
            standard_logging_payload.get("cost_breakdown", None)
            if standard_logging_payload is not None
            else None
        ),
    )

    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            if isinstance(v, BaseModel):
                v = v.model_dump()
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id

    mcp_namespaced_tool_name = None
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = clean_metadata.get(
        "mcp_tool_call_metadata"
    )
    if mcp_tool_call_metadata is not None:
        mcp_namespaced_tool_name = mcp_tool_call_metadata.get(
            "namespaced_tool_name", None
        )

    # Extract agent_id for A2A requests (set directly on model_call_details)
    agent_id: Optional[str] = kwargs.get("agent_id") or metadata.get("agent_id")
    custom_llm_provider = kwargs.get("custom_llm_provider")
    raw_model = cast(str, kwargs.get("model") or "")
    model_name = reconstruct_model_name(raw_model, custom_llm_provider, metadata or {})

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=_ensure_datetime_utc(start_time),
            endTime=_ensure_datetime_utc(end_time),
            completionStartTime=_ensure_datetime_utc(completion_start_time),
            model=model_name,
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            organization_id=metadata.get("user_api_key_org_id") or "",
            metadata=safe_dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", standard_logging_total_tokens),
            prompt_tokens=usage.get("prompt_tokens", standard_logging_prompt_tokens),
            completion_tokens=usage.get(
                "completion_tokens", standard_logging_completion_tokens
            ),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            mcp_namespaced_tool_name=mcp_namespaced_tool_name,
            agent_id=agent_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(
                standard_logging_payload=standard_logging_payload, metadata=metadata
            ),
            response=_get_response_for_spend_logs_payload(
                payload=standard_logging_payload, kwargs=kwargs
            ),
            proxy_server_request=_get_proxy_server_request_for_spend_logs_payload(
                metadata=metadata, litellm_params=litellm_params, kwargs=kwargs
            ),
            session_id=_get_session_id_for_spend_log(
                kwargs=kwargs,
                standard_logging_payload=standard_logging_payload,
            ),
            request_duration_ms=_get_request_duration_ms(start_time, end_time),
            status=_get_status_for_spend_log(
                metadata=metadata,
            ),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - request_id: %s, model: %s, spend: %s",
            payload.get("request_id"),
            payload.get("model"),
            payload.get("spend"),
        )

        # Explicitly clear large intermediate objects to reduce memory pressure
        del response_obj_dict, usage, clean_metadata, additional_usage_values

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        raise e


def _get_session_id_for_spend_log(
    kwargs: dict,
    standard_logging_payload: Optional[StandardLoggingPayload],
) -> str:
    """
    Get the session id for the spend log.

    This ensures each spend log is associated with a unique session id.

    """
    from litellm._uuid import uuid


    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("trace_id") is not None
    ):
        return str(standard_logging_payload.get("trace_id"))

    # Users can dynamically set the trace_id for each request by passing `litellm_trace_id` in kwargs
    if kwargs.get("litellm_trace_id") is not None:
        return str(kwargs.get("litellm_trace_id"))

    # Ensure we always have a session id, if none is provided
    return str(uuid.uuid4())


def _get_request_duration_ms(start_time: datetime, end_time: datetime) -> Optional[int]:
    """Compute request duration in milliseconds from start and end times."""
    try:
        return int((end_time - start_time).total_seconds() * 1000)
    except Exception:
        return None


def _ensure_datetime_utc(timestamp: datetime) -> datetime:
    """Helper to ensure datetime is in UTC"""
    timestamp = timestamp.astimezone(timezone.utc)
    return timestamp


async def get_spend_by_team_and_customer(
    start_date: datetime,
    end_date: datetime,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
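    """
    Aggregate daily spend for a given team and end-user (customer) between
    start_date and end_date (end date inclusive), broken down per model and
    api_key.
    """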
    sql_query = """
    WITH SpendByModelApiKey AS (
        SELECT
            date_trunc('day', sl."startTime") AS group_by_day,
            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
            sl.end_user AS customer,
            sl.model,
            sl.api_key,
            SUM(sl.spend) AS model_api_spend,
            SUM(sl.total_tokens) AS model_api_tokens
        FROM 
            "LiteLLM_SpendLogs" sl
        LEFT JOIN 
            "LiteLLM_TeamTable" tt 
        ON 
            sl.team_id = tt.team_id
        WHERE
            sl."startTime" >= $1::timestamptz AND sl."startTime" < ($2::timestamptz + INTERVAL '1 day')
            AND sl.team_id = $3
            AND sl.end_user = $4
        GROUP BY
            date_trunc('day', sl."startTime"),
            tt.team_alias,
            sl.end_user,
            sl.model,
            sl.api_key
    )
        SELECT
            group_by_day,
            jsonb_agg(jsonb_build_object(
                'team_name', team_name,
                'customer', customer,
                'total_spend', total_spend,
                'metadata', metadata
            )) AS teams_customers
        FROM (
            SELECT
                group_by_day,
                team_name,
                customer,
                SUM(model_api_spend) AS total_spend,
                jsonb_agg(jsonb_build_object(
                    'model', model,
                    'api_key', api_key,
                    'spend', model_api_spend,
                    'total_tokens', model_api_tokens
                )) AS metadata
            FROM 
                SpendByModelApiKey
            GROUP BY
                group_by_day,
                team_name,
                customer
        ) AS aggregated
        GROUP BY
            group_by_day
        ORDER BY
            group_by_day;
    """

    db_response = await prisma_client.db.query_raw(
        sql_query, start_date, end_date, team_id, customer_id
    )
    if db_response is None:
        return []
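
    # Illustrative row shape (one row per day):
    #   {"group_by_day": <timestamptz>,
    #    "teams_customers": [{"team_name": ..., "customer": ...,
    #                         "total_spend": ..., "metadata": [
    #                             {"model": ..., "api_key": ...,
    #                              "spend": ..., "total_tokens": ...}]}]}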

    return db_response


def _get_messages_for_spend_logs_payload(
    standard_logging_payload: Optional[StandardLoggingPayload],
    metadata: Optional[dict] = None,
) -> str:
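    """
    Return the request messages as a JSON string for the spend log.

    Currently only realtime (`_arealtime`) calls persist messages here; other
    call types, disabled prompt storage, and serialization failures all
    return "{}".
    """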
    if _should_store_prompts_and_responses_in_spend_logs():
        if standard_logging_payload is not None:
            call_type = standard_logging_payload.get("call_type", "")
            if call_type == "_arealtime":
                messages = standard_logging_payload.get("messages")
                if messages is not None:
                    try:
                        return json.dumps(messages, default=str)
                    except Exception:
                        return "{}"
    return "{}"


def _sanitize_request_body_for_spend_logs_payload(
    request_body: dict,
    visited: Optional[set] = None,
    max_string_length_prompt_in_db: Optional[int] = None,
) -> dict:
    """
    Recursively sanitize a request body so large values (e.g. base64-encoded
    payloads) are not stored verbatim. Strings longer than
    MAX_STRING_LENGTH_PROMPT_IN_DB are truncated keeping both head and tail;
    nested dicts and lists are handled, and already-visited dicts are skipped
    to avoid cycles.
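
    Example (illustrative, assuming a 100-char limit): a 5,000-char base64
    string keeps its first 35 and last 65 characters around a marker noting
    that 4,900 chars were skipped, so both ends stay inspectable.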
    """
    if visited is None:
        visited = set()
    if max_string_length_prompt_in_db is None:
        max_string_length_prompt_in_db = _get_max_string_length_prompt_in_db()

    # Get the object's memory address to track visited objects
    obj_id = id(request_body)
    if obj_id in visited:
        return {}
    visited.add(obj_id)

    def _sanitize_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _sanitize_request_body_for_spend_logs_payload(
                value, visited, max_string_length_prompt_in_db
            )
        elif isinstance(value, list):
            return [_sanitize_value(item) for item in value]
        elif isinstance(value, str):
            if len(value) > max_string_length_prompt_in_db:
                # Keep 35% from beginning and 65% from end (end is usually more important)
                # This split ensures we keep more context from the end of conversations
                start_ratio = 0.35
                end_ratio = 0.65

                # Calculate character distribution
                start_chars = int(max_string_length_prompt_in_db * start_ratio)
                end_chars = int(max_string_length_prompt_in_db * end_ratio)

                # Ensure we don't exceed the total limit
                total_keep = start_chars + end_chars
                if total_keep > max_string_length_prompt_in_db:
                    end_chars = max_string_length_prompt_in_db - start_chars

                # Calculate how many characters are being skipped
                skipped_chars = len(value) - total_keep

                # Build the truncated string: beginning + truncation marker + end
                truncated_value = (
                    f"{value[:start_chars]}"
                    f"... ({LITELLM_TRUNCATED_PAYLOAD_FIELD} skipped {skipped_chars} chars. "
                    f"{LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE}) ..."
                    f"{value[-end_chars:]}"
                )
                return truncated_value
            return value
        return value

    return {k: _sanitize_value(v) for k, v in request_body.items()}


def _convert_to_json_serializable_dict(
    obj: Any, visited: Optional[set] = None, max_depth: int = 20
) -> Any:
    """
    Convert object to JSON-serializable dict, handling Pydantic models safely.

    This avoids pickle-based deepcopy which fails on Pydantic v2 models
    containing _thread.RLock objects.

    Args:
        obj: Object to convert (dict, list, Pydantic model, or primitive)
        visited: Set of object IDs to track circular references
        max_depth: Maximum recursion depth to prevent infinite recursion

    Returns:
        JSON-serializable version of the object
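
    Example (illustrative, with a hypothetical model)::

        class Point(BaseModel):
            x: int
            y: int

        _convert_to_json_serializable_dict({"p": Point(x=1, y=2)})
        # -> {"p": {"x": 1, "y": 2}}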
    """
    if max_depth <= 0:
        # Return a placeholder if max depth is exceeded
        return "<max_depth_exceeded>"

    if visited is None:
        visited = set()

    # Get the object's memory address to track visited objects
    obj_id = id(obj)
    if obj_id in visited:
        # Circular reference detected, return placeholder
        return "<circular_reference>"

    # Only track mutable objects (dict, list, objects with __dict__)
    if isinstance(obj, (dict, list)) or hasattr(obj, "__dict__"):
        visited.add(obj_id)

    try:
        if isinstance(obj, BaseModel):
            # Use Pydantic's model_dump() instead of pickle
            result = obj.model_dump()
            # Recursively process the dumped dict
            return _convert_to_json_serializable_dict(result, visited, max_depth - 1)
        elif isinstance(obj, dict):
            return {
                k: _convert_to_json_serializable_dict(v, visited, max_depth - 1)
                for k, v in obj.items()
            }
        elif isinstance(obj, list):
            return [
                _convert_to_json_serializable_dict(item, visited, max_depth - 1)
                for item in obj
            ]
        elif hasattr(obj, "__dict__"):
            # Handle objects with __dict__ attribute
            return _convert_to_json_serializable_dict(
                obj.__dict__, visited, max_depth - 1
            )
        else:
            # Primitives (str, int, float, bool, None) pass through
            return obj
    finally:
        # Remove from visited set when done processing this object
        if obj_id in visited:
            visited.remove(obj_id)


def _get_proxy_server_request_for_spend_logs_payload(
    metadata: dict,
    litellm_params: dict,
    kwargs: Optional[dict] = None,
) -> str:
    """
    Only store if _should_store_prompts_and_responses_in_spend_logs() is True

    If turn_off_message_logging is enabled, redact messages in the request body.
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        _proxy_server_request = cast(
            Optional[dict], litellm_params.get("proxy_server_request", {})
        )
        if _proxy_server_request is not None:
            _request_body = _proxy_server_request.get("body", {}) or {}

            if kwargs is not None:
                realtime_tools = kwargs.get("realtime_tools")
                if realtime_tools:
                    _request_body = dict(_request_body)
                    _request_body["tools"] = realtime_tools

            # Apply message redaction if turn_off_message_logging is enabled
            if kwargs is not None:
                from litellm.litellm_core_utils.redact_messages import (
                    perform_redaction,
                    should_redact_message_logging,
                )

                # Build model_call_details dict to check redaction settings
                model_call_details = {
                    "litellm_params": litellm_params,
                    "standard_callback_dynamic_params": kwargs.get(
                        "standard_callback_dynamic_params"
                    ),
                }

                # If redaction is enabled, convert to serializable dict before redacting
                if should_redact_message_logging(model_call_details=model_call_details):
                    _request_body = _convert_to_json_serializable_dict(_request_body)
                    perform_redaction(model_call_details=_request_body, result=None)

            _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
            _request_body_json_str = json.dumps(_request_body, default=str)
            if LITELLM_TRUNCATED_PAYLOAD_FIELD in _request_body_json_str:
                verbose_proxy_logger.info(
                    "Spend Log: request body was truncated before storing in DB. %s",
                    LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
                )
            return _request_body_json_str
    return "{}"


def _get_vector_store_request_for_spend_logs_payload(
    vector_store_request_metadata: Optional[List[StandardLoggingVectorStoreRequest]],
) -> Optional[List[StandardLoggingVectorStoreRequest]]:
    """
    If user does not want to store prompts and responses, then remove the content from the vector store request metadata
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        return vector_store_request_metadata

    if vector_store_request_metadata is None:
        return None
    for vector_store_request in vector_store_request_metadata:
        vector_store_search_response: VectorStoreSearchResponse = (
            vector_store_request.get("vector_store_search_response")
            or VectorStoreSearchResponse()
        )
        response_data = vector_store_search_response.get("data", []) or []
        for response_item in response_data:
            for content_item in response_item.get("content", []) or []:
                if "text" in content_item:
                    content_item["text"] = REDACTED_BY_LITELM_STRING
    return vector_store_request_metadata


def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
    kwargs: Optional[dict] = None,
) -> str:
    if payload is None:
        return "{}"
    if _should_store_prompts_and_responses_in_spend_logs():
        response_obj: Any = payload.get("response")
        if response_obj is None:
            return "{}"

        if kwargs is not None:
            realtime_tool_calls = kwargs.get("realtime_tool_calls")
            if realtime_tool_calls and isinstance(response_obj, dict):
                response_obj = dict(response_obj)
                response_obj["tool_calls"] = realtime_tool_calls

        # Apply message redaction if turn_off_message_logging is enabled
        if kwargs is not None:
            from litellm.litellm_core_utils.redact_messages import (
                perform_redaction,
                should_redact_message_logging,
            )

            litellm_params = kwargs.get("litellm_params", {})
            model_call_details = {
                "litellm_params": litellm_params,
                "standard_callback_dynamic_params": kwargs.get(
                    "standard_callback_dynamic_params"
                ),
            }

            # If redaction is enabled, convert to serializable dict before redacting
            if should_redact_message_logging(model_call_details=model_call_details):
                response_obj = _convert_to_json_serializable_dict(response_obj)
                response_obj = perform_redaction(
                    model_call_details={}, result=response_obj
                )

        sanitized_wrapper = _sanitize_request_body_for_spend_logs_payload(
            {"response": response_obj}
        )

        sanitized_response = sanitized_wrapper.get("response", response_obj)

        if sanitized_response is None:
            return "{}"
        if isinstance(sanitized_response, str):
            result_str = sanitized_response
        else:
            result_str = safe_dumps(sanitized_response)
        if LITELLM_TRUNCATED_PAYLOAD_FIELD in result_str:
            verbose_proxy_logger.info(
                "Spend Log: response was truncated before storing in DB. %s",
                LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
            )
        return result_str
    return "{}"


def _should_store_prompts_and_responses_in_spend_logs() -> bool:
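    """
    Prompt/response storage can be enabled either via proxy config, e.g.
    (illustrative):

        general_settings:
          store_prompts_in_spend_logs: true

    or via the STORE_PROMPTS_IN_SPEND_LOGS environment variable.
    """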
    from litellm.proxy.proxy_server import general_settings
    from litellm.secret_managers.main import get_secret_bool

    # Check general_settings (from DB or proxy_config.yaml)
    store_prompts_value = general_settings.get("store_prompts_in_spend_logs")

    # Accept boolean True or a case-insensitive "true" string; anything else
    # falls through to the environment variable check below.
    if store_prompts_value is True:
        return True
    elif isinstance(store_prompts_value, str):
        # Case-insensitive string comparison
        if store_prompts_value.lower() == "true":
            return True

    # Also check environment variable
    return get_secret_bool("STORE_PROMPTS_IN_SPEND_LOGS") is True


def _get_status_for_spend_log(
    metadata: dict,
) -> Literal["success", "failure"]:
    """
    Get the status for the spend log.

    It's only a failure if metadata.get("status") is "failure"
    """
    _status: Optional[str] = metadata.get("status", None)
    if _status == "failure":
        return "failure"
    return "success"
