#### What this does ####
# This file contains the LiteralAILogger class which is used to log steps to the LiteralAI observability platform.
import asyncio
import os
from litellm._uuid import uuid
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class LiteralAILogger(CustomBatchLogger):
    """Batching LiteLLM callback that records each call as an "llm" step in Literal AI.

    Every success/failure event is converted into a step dict by
    ``_prepare_log_data`` and appended to ``self.log_queue``. Once the queue
    holds ``self.batch_size`` steps, the batch is sent as a single GraphQL
    mutation (one ``ingestStep`` field per step) to
    ``<literalai_api_url>/api/graphql``. ``log_queue`` and ``batch_size`` are
    not assigned in this class; presumably they are set up by
    ``CustomBatchLogger``.
    """

    def __init__(
        self,
        literalai_api_key=None,
        literalai_api_url="https://cloud.getliteral.ai",
        env=None,
        **kwargs,
    ):
        """Configure request headers, HTTP clients and batching.

        Args:
            literalai_api_key: Sent as the ``x-api-key`` header; when falsy the
                ``LITERAL_API_KEY`` environment variable is used instead.
            literalai_api_url: Base URL of the Literal AI API. Note the reversed
                precedence compared to the key: a non-empty ``LITERAL_API_URL``
                environment variable overrides this argument.
            env: When truthy, sent as the ``x-env`` header.
            **kwargs: Forwarded to ``CustomBatchLogger.__init__``.

        The batch size is read from ``LITERAL_BATCH_SIZE`` (must parse as an
        int); when it is unset or empty, ``None`` is passed to the base class.
        """
        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
        self.headers = {
            "Content-Type": "application/json",
            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
            "x-client-name": "litellm",
        }
        if env:
            self.headers["x-env"] = env
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()
        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
        # The same lock object is kept on the instance and handed to the base class.
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs,
            flush_lock=self.flush_lock,
            batch_size=int(batch_size) if batch_size else None,
        )

    # ------------------------------------------------------------------ #
    # Queueing
    # ------------------------------------------------------------------ #

    def _enqueue_step(self, kwargs, response_obj, start_time, end_time) -> bool:
        """Build the step for one call and append it to ``self.log_queue``.

        Returns:
            True when the queue has reached ``self.batch_size`` and should be sent.

        Raises:
            ValueError: propagated from ``_prepare_log_data`` when
                ``standard_logging_object`` is missing from ``kwargs``.
        """
        data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
        self.log_queue.append(data)
        verbose_logger.debug(
            "Literal AI logging: queue length %s, batch size %s",
            len(self.log_queue),
            self.batch_size,
        )
        return len(self.log_queue) >= self.batch_size

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a successful call; send the batch synchronously once it is full.

        Any error is logged and swallowed so logging never breaks the caller.
        """
        try:
            verbose_logger.debug(
                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            if self._enqueue_step(kwargs, response_obj, start_time, end_time):
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging success event."
            )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failed call; send the batch synchronously once it is full.

        Any error is logged and swallowed so logging never breaks the caller.
        """
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            if self._enqueue_step(kwargs, response_obj, start_time, end_time):
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging failure event."
            )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a successful call; await ``flush_queue`` once the batch is full.

        Any error is logged and swallowed so logging never breaks the caller.
        """
        try:
            verbose_logger.debug(
                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            if self._enqueue_step(kwargs, response_obj, start_time, end_time):
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failed call; await ``flush_queue`` once the batch is full.

        Any error is logged and swallowed so logging never breaks the caller.
        """
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            if self._enqueue_step(kwargs, response_obj, start_time, end_time):
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async failure event."
            )

    # ------------------------------------------------------------------ #
    # Sending
    # ------------------------------------------------------------------ #

    def _graphql_request(self, steps):
        """Return ``(url, json_body)`` for one mutation that ingests ``steps``."""
        url = f"{self.literalai_api_url}/api/graphql"
        body = {
            "query": self._steps_query_builder(steps),
            "variables": self._steps_variables_builder(steps),
        }
        return url, body

    def _send_batch(self):
        """Synchronously send every queued step and empty the queue.

        Nothing else empties the queue on the sync path. Without this, each
        later event would re-send all earlier steps and the queue would grow
        without bound. The batch is taken out of the queue before the request,
        so a failed request drops it instead of retrying it forever. Errors are
        logged, never raised.
        """
        if not self.log_queue:
            return

        batch = list(self.log_queue)
        self.log_queue.clear()

        url, body = self._graphql_request(batch)
        try:
            response = self.sync_http_handler.post(
                url=url,
                json=body,
                headers=self.headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    "Literal AI Error: %s - %s", response.status_code, response.text
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(batch)
                )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    async def async_send_batch(self):
        """Asynchronously send every step currently in ``self.log_queue``.

        This method does not clear the queue. Clearing is presumably done by
        the caller (``CustomBatchLogger.flush_queue``) after this returns.
        NOTE(review): steps appended while the request below is awaited could
        then be cleared without being sent; confirm against the base class.
        Errors are logged, never raised.
        """
        if not self.log_queue:
            return

        url, body = self._graphql_request(self.log_queue)

        try:
            response = await self.async_httpx_client.post(
                url=url,
                json=body,
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    "Literal AI Error: %s - %s", response.status_code, response.text
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                "Literal AI HTTP Error: %s - %s",
                e.response.status_code,
                e.response.text,
            )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    # ------------------------------------------------------------------ #
    # Payload construction
    # ------------------------------------------------------------------ #

    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
        """Convert one LiteLLM callback invocation into a Literal AI step dict.

        Args:
            kwargs: Callback kwargs.
                - Must contain ``"standard_logging_object"``.
                - Optionally contains ``"litellm_params"``, whose ``"metadata"``
                  may supply ``step_id``, ``literalai_thread_id``,
                  ``literalai_parent_id``, ``literalai_root_run_id`` and
                  ``tags`` / ``literalai_tags``.
                - ``"model"`` and ``"custom_llm_provider"`` are also read.
            response_obj: Unused here; the response is taken from the standard
                logging payload instead.
            start_time, end_time: Stringified for the step and subtracted to get
                the duration (``.total_seconds()``); presumably datetimes.

        Returns:
            A step dict whose keys match the ``ingestStep`` arguments used by
            ``_steps_ingest_steps_builder``.

        Raises:
            ValueError: if ``standard_logging_object`` is missing.

        The shared standard logging payload is not modified. Other callbacks
        may read the same object.
        """
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        clean_metadata = logging_payload["metadata"]
        # Either level may be present with an explicit None value.
        litellm_params = kwargs.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}

        # Copy so that popping "tools" below leaves the shared payload intact.
        settings = dict(logging_payload["model_parameters"] or {})
        tools = settings.pop("tools", None)

        messages = logging_payload["messages"]
        prompt_id = None
        variables = None
        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
            tagged_messages = []
            for message in messages:
                # Prompt info is read as an attribute (getattr), not as a dict key.
                literal_prompt = getattr(message, "__literal_prompt__", None)
                if literal_prompt:
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    # Annotate a copy rather than the shared message dict.
                    message = {
                        **message,
                        "uuid": literal_prompt.get("uuid"),
                        "templated": True,
                    }
                tagged_messages.append(message)
            messages = tagged_messages

        response = logging_payload["response"]
        choices: List = []
        if isinstance(response, dict) and "choices" in response:
            choices = response["choices"] or []
        message_completion = None
        if choices:
            first_choice = choices[0]
            # A choice dict may lack "message" (e.g. text-completion choices);
            # don't let that fail the whole log.
            message_completion = (
                first_choice.get("message")
                if isinstance(first_choice, dict)
                else first_choice["message"]
            )

        step = {
            # `or` also covers an explicit None step_id. A None id would be
            # dropped by the variables builder, and `$id_i: String!` is required.
            "id": metadata.get("step_id") or str(uuid.uuid4()),
            "error": logging_payload["error_str"],
            "name": kwargs.get("model", ""),
            "threadId": metadata.get("literalai_thread_id", None),
            "parentId": metadata.get("literalai_parent_id", None),
            "rootRunId": metadata.get("literalai_root_run_id", None),
            "input": None,
            "output": None,
            "type": "llm",
            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
            "startTime": str(start_time),
            "endTime": str(end_time),
            "metadata": clean_metadata,
            "generation": {
                "inputTokenCount": logging_payload["prompt_tokens"],
                "outputTokenCount": logging_payload["completion_tokens"],
                "tokenCount": logging_payload["total_tokens"],
                "promptId": prompt_id,
                "variables": variables,
                "provider": kwargs.get("custom_llm_provider", "litellm"),
                "model": kwargs.get("model", ""),
                "duration": (end_time - start_time).total_seconds(),
                "settings": settings,
                "messages": messages,
                "messageCompletion": message_completion,
                "tools": tools,
            },
        }
        return step

    def _steps_query_variables_builder(self, steps):
        """Return the GraphQL variable declarations for ``len(steps)`` steps.

        Step ``i`` declares ``$<argument>_i`` for every ``ingestStep`` argument.
        """

        def declare(i):
            return f"""$id_{i}: String!
            $threadId_{i}: String
            $rootRunId_{i}: String
            $type_{i}: StepType
            $startTime_{i}: DateTime
            $endTime_{i}: DateTime
            $error_{i}: String
            $input_{i}: Json
            $output_{i}: Json
            $metadata_{i}: Json
            $parentId_{i}: String
            $name_{i}: String
            $tags_{i}: [String!]
            $generation_{i}: GenerationPayloadInput
            $scores_{i}: [ScorePayloadInput!]
            $attachments_{i}: [AttachmentPayloadInput!]
            """

        return "".join(declare(i) for i in range(len(steps)))

    def _steps_ingest_steps_builder(self, steps):
        """Return one aliased ``ingestStep`` field (``step<i>``) per step.

        Each field binds its arguments to the ``$<argument>_i`` variables.
        """

        def ingest(i):
            return f"""
        step{i}: ingestStep(
            id: $id_{i}
            threadId: $threadId_{i}
            rootRunId: $rootRunId_{i}
            startTime: $startTime_{i}
            endTime: $endTime_{i}
            type: $type_{i}
            error: $error_{i}
            input: $input_{i}
            output: $output_{i}
            metadata: $metadata_{i}
            parentId: $parentId_{i}
            name: $name_{i}
            tags: $tags_{i}
            generation: $generation_{i}
            scores: $scores_{i}
            attachments: $attachments_{i}
        ) {{
            ok
            message
        }}
    """

        return "".join(ingest(i) for i in range(len(steps)))

    def _steps_query_builder(self, steps):
        """Return the full ``AddStep`` mutation document for ``steps``."""
        return f"""
        mutation AddStep({self._steps_query_variables_builder(steps)}) {{
        {self._steps_ingest_steps_builder(steps)}
        }}
        """

    def _steps_variables_builder(self, steps):
        """Flatten steps into GraphQL variables named ``<key>_<index>``.

        None values are skipped so they do not override existing values
        server-side.
        """
        variables = {}
        for i, step in enumerate(steps):
            for key, value in step.items():
                if value is not None:
                    variables[f"{key}_{i}"] = value
        return variables
