import json
from typing import Any, List, Literal, Optional, Tuple

import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import Batch
from litellm.types.utils import CallTypes, ModelInfo, Usage
from litellm.utils import token_counter


async def calculate_batch_cost_and_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"],
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Tuple[float, Usage, List[str]]:
    """
    Compute the cost, aggregated token usage and model list of a batch.

    Args:
        file_content_dictionary: Parsed lines of the batch output file.
        custom_llm_provider: Provider the batch was executed against.
        model_name: Optional model name, forwarded to every helper.
        model_info: Optional deployment-level model info with custom batch
            pricing. It is forwarded to the cost calculator only, so that
            deployment-specific pricing (e.g. input_cost_per_token_batches)
            is used instead of the global cost map.

    Returns:
        A ``(batch_cost, batch_usage, batch_models)`` tuple.
    """
    # Tuple elements are evaluated left to right: cost, then usage, then models.
    return (
        _batch_cost_calculator(
            file_content_dictionary=file_content_dictionary,
            custom_llm_provider=custom_llm_provider,
            model_name=model_name,
            model_info=model_info,
        ),
        _get_batch_job_total_usage_from_file_content(
            file_content_dictionary=file_content_dictionary,
            custom_llm_provider=custom_llm_provider,
            model_name=model_name,
        ),
        _get_batch_models_from_file_content(file_content_dictionary, model_name),
    )


async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"],
    model_name: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
    """
    Download a completed batch's output file and summarise it.

    Args:
        batch: The batch object whose output file is read.
        custom_llm_provider: The LLM provider.
        model_name: Optional model name, forwarded to the cost, usage and
            model helpers.
        litellm_params: Optional litellm parameters containing credentials
            (api_key, api_base, etc.) used when fetching the output file.

    Returns:
        A ``(batch_cost, batch_usage, batch_models)`` tuple.
    """
    # Step 1: fetch and parse the output file of the batch.
    output_lines = await _get_batch_output_file_content_as_dictionary(
        batch,
        custom_llm_provider,
        litellm_params=litellm_params,
    )

    # Step 2: derive cost, usage and models from the parsed lines.
    cost = _batch_cost_calculator(
        file_content_dictionary=output_lines,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
    usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=output_lines,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
    models = _get_batch_models_from_file_content(output_lines, model_name)

    return cost, usage, models


def _get_batch_models_from_file_content(
    file_content_dictionary: List[dict],
    model_name: Optional[str] = None,
) -> List[str]:
    """
    Get the models from the file content
    """
    if model_name:
        return [model_name]
    batch_models = []
    for _item in file_content_dictionary:
        if _batch_response_was_successful(_item):
            _response_body = _get_response_from_batch_job_output_file(_item)
            _model = _response_body.get("model")
            if _model:
                batch_models.append(_model)
    return batch_models


def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Calculate the total cost of a batch from its parsed output lines.

    Vertex AI batches with a known ``model_name`` are priced by
    ``calculate_vertex_ai_batch_cost_and_usage``. Every other case goes
    through ``_get_batch_job_cost_from_file_content``, which also receives
    ``model_info``.
    """
    if model_name and custom_llm_provider == "vertex_ai":
        # The Vertex helper returns (cost, usage); only the cost is needed here.
        vertex_cost = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )[0]
        verbose_logger.debug("vertex_ai_total_cost=%s", vertex_cost)
        return vertex_cost

    generic_cost = _get_batch_job_cost_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_info=model_info,
    )
    verbose_logger.debug("total_cost=%s", generic_cost)
    return generic_cost


def calculate_vertex_ai_batch_cost_and_usage(
    vertex_ai_batch_responses: List[dict],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
    """
    Calculate both cost and usage from Vertex AI batch responses.

    Vertex AI batch output lines have format:
      {"request": ..., "status": "", "response": {"candidates": [...], "usageMetadata": {...}}}

    usageMetadata contains promptTokenCount, candidatesTokenCount, totalTokenCount.

    Args:
        vertex_ai_batch_responses: Parsed output lines of the batch job.
        model_name: Model used to price every line; falls back to
            "gemini-2.0-flash-001" when not given.

    Returns:
        Tuple of (total cost, aggregated Usage). Lines without a
        "response" are skipped. A line whose cost calculation raises still
        contributes its tokens to the usage but adds nothing to the cost.
    """
    from litellm.cost_calculator import batch_cost_calculator

    total_cost = 0.0
    total_tokens = 0
    prompt_tokens = 0
    completion_tokens = 0
    actual_model_name = model_name or "gemini-2.0-flash-001"

    for response in vertex_ai_batch_responses:
        response_body = response.get("response")
        if response_body is None:
            continue

        # `or {}` also covers an explicit `"usageMetadata": null`, which a
        # plain `.get(key, {})` default would let through as None.
        usage_metadata = response_body.get("usageMetadata") or {}
        _prompt = usage_metadata.get("promptTokenCount", 0) or 0
        _completion = usage_metadata.get("candidatesTokenCount", 0) or 0
        # Fall back to prompt + completion when the total is missing or zero.
        _total = usage_metadata.get("totalTokenCount", 0) or (_prompt + _completion)

        line_usage = Usage(
            prompt_tokens=_prompt,
            completion_tokens=_completion,
            total_tokens=_total,
        )

        try:
            p_cost, c_cost = batch_cost_calculator(
                usage=line_usage,
                model=actual_model_name,
                custom_llm_provider="vertex_ai",
            )
            total_cost += p_cost + c_cost
        except Exception as e:
            # Best effort: a pricing failure on one line must not abort the
            # whole batch; the line's tokens are still counted below.
            verbose_logger.debug(
                "vertex_ai batch cost calculation error for line: %s", str(e)
            )

        prompt_tokens += _prompt
        completion_tokens += _completion
        total_tokens += _total

    verbose_logger.info(
        "vertex_ai batch cost: cost=%s, prompt=%d, completion=%d, total=%d",
        total_cost, prompt_tokens, completion_tokens, total_tokens,
    )

    return total_cost, Usage(
        total_tokens=total_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )


async def _get_batch_output_file_content_as_dictionary(
    batch: Batch,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
    litellm_params: Optional[dict] = None,
) -> List[dict]:
    """
    Fetch the output file of a batch and parse it into a list of dictionaries.

    Args:
        batch: The batch object; its ``output_file_id`` identifies the file.
        custom_llm_provider: The LLM provider.
        litellm_params: Optional litellm parameters containing credentials
            (api_key, api_base, etc.). Required for Azure and other
            providers that need authentication.

    Raises:
        ValueError: If the provider is ``vertex_ai`` or the batch has no
            output file id.
    """
    from litellm.files.main import afile_content
    from litellm.proxy.openai_files_endpoints.common_utils import (
        _is_base64_encoded_unified_file_id,
    )

    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")

    output_file_id = batch.output_file_id
    if output_file_id is None:
        raise ValueError("Output file id is None cannot retrieve file content")

    # For a unified file ID, pull out the provider-side output file ID;
    # if extraction fails the original ID is used unchanged.
    resolved_file_id = output_file_id
    decoded_unified_id = _is_base64_encoded_unified_file_id(output_file_id)
    if decoded_unified_id:
        try:
            after_marker = decoded_unified_id.split("llm_output_file_id,")[1]
            resolved_file_id = after_marker.split(";")[0]
            verbose_logger.debug(f"Extracted LLM output file ID from unified file ID: {resolved_file_id}")
        except (IndexError, AttributeError) as e:
            verbose_logger.error(f"Failed to extract LLM output file ID from unified file ID: {output_file_id}, error: {e}")

    # File identity first, then any credentials taken from litellm_params.
    request_kwargs = {
        "file_id": resolved_file_id,
        "custom_llm_provider": custom_llm_provider,
        **_extract_file_access_credentials(litellm_params),
    }

    downloaded_file = await afile_content(**request_kwargs)  # type: ignore[reportArgumentType]
    return _get_file_content_as_dictionary(downloaded_file.content)


def _extract_file_access_credentials(litellm_params: Optional[dict]) -> dict:
    """
    Extract credentials from litellm_params for file access operations.
    
    This method extracts relevant authentication and configuration parameters
    needed for accessing files across different providers (Azure, Vertex AI, etc.).
    
    Args:
        litellm_params: Dictionary containing litellm parameters with credentials
        
    Returns:
        Dictionary containing only the credentials needed for file access
    """
    credentials = {}
    
    if litellm_params:
        # List of credential keys that should be passed to file operations
        credential_keys = [
            "api_key", "api_base", "api_version", "organization",
            "azure_ad_token", "azure_ad_token_provider",
            "vertex_project", "vertex_location", "vertex_credentials",
            "timeout", "max_retries"
        ]
        for key in credential_keys:
            if key in litellm_params:
                credentials[key] = litellm_params[key]
    
    return credentials


def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
    """
    Get the file content as a list of dictionaries from JSON Lines format

    Each non-blank line is parsed as one JSON document. Blank and
    whitespace-only lines (including a lone "\\r" left over from CRLF line
    endings) are skipped.

    Args:
        file_content: UTF-8 encoded JSON Lines payload.

    Returns:
        The decoded JSON value of every non-blank line, in file order.

    Raises:
        UnicodeDecodeError: If ``file_content`` is not valid UTF-8.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    _file_content_str = file_content.decode("utf-8")

    json_objects = []
    # Split on "\n" only: str.splitlines() would also break on characters such
    # as U+2028/U+2029, which may appear unescaped inside JSON string values.
    for raw_line in _file_content_str.split("\n"):
        line = raw_line.strip()
        if not line:
            # Nothing to parse; json.loads would raise on this line.
            continue
        json_objects.append(json.loads(line))

    verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
    return json_objects


def _get_batch_job_cost_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Get the cost of a batch job from the file content

    Only lines whose response has status code 200 are priced.

    Args:
        file_content_dictionary: Parsed lines of the batch output file.
        custom_llm_provider: Provider used for the cost lookup.
        model_info: Optional deployment-level model info. When given, each
            line is priced with ``batch_cost_calculator`` using this info
            and the line's own usage/model; otherwise
            ``litellm.completion_cost`` is used on the response body.

    Returns:
        The summed cost of all successful lines.

    Raises:
        Exception: Any error raised while serialising or pricing the lines
            is logged and re-raised unchanged.
    """
    from litellm.cost_calculator import batch_cost_calculator

    try:
        total_cost: float = 0.0
        # parse the file content as json
        verbose_logger.debug(
            "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
        )
        for _item in file_content_dictionary:
            if _batch_response_was_successful(_item):
                _response_body = _get_response_from_batch_job_output_file(_item)
                if model_info is not None:
                    usage = _get_batch_job_usage_from_response_body(_response_body)
                    model = _response_body.get("model", "")
                    prompt_cost, completion_cost = batch_cost_calculator(
                        usage=usage,
                        model=model,
                        custom_llm_provider=custom_llm_provider,
                        model_info=model_info,
                    )
                    total_cost += prompt_cost + completion_cost
                else:
                    total_cost += litellm.completion_cost(
                        completion_response=_response_body,
                        custom_llm_provider=custom_llm_provider,
                        call_type=CallTypes.aretrieve_batch.value,
                    )
                verbose_logger.debug("total_cost=%s", total_cost)
        return total_cost
    except Exception as e:
        # The message needs a %s placeholder for the exception argument;
        # without one the logging module reports a formatting error and the
        # actual error text never reaches the log.
        verbose_logger.error("error in _get_batch_job_cost_from_file_content: %s", e)
        raise


def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Sum the token usage of every successful line in a batch output file.

    Vertex AI batches with a known ``model_name`` are delegated to
    ``calculate_vertex_ai_batch_cost_and_usage``; all other cases read the
    ``usage`` object of each successful response body.
    """
    if custom_llm_provider == "vertex_ai" and model_name:
        # The Vertex helper returns (cost, usage); only the usage is needed.
        return calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )[1]

    totals = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
    for entry in file_content_dictionary:
        if not _batch_response_was_successful(entry):
            continue
        line_usage: Usage = _get_batch_job_usage_from_response_body(
            _get_response_from_batch_job_output_file(entry)
        )
        totals["total_tokens"] += line_usage.total_tokens
        totals["prompt_tokens"] += line_usage.prompt_tokens
        totals["completion_tokens"] += line_usage.completion_tokens

    return Usage(**totals)

def _get_batch_job_input_file_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Count the number of tokens in the input file

    Used for batch rate limiting to count the number of tokens in the input file

    Args:
        file_content_dictionary: Parsed lines of the batch input file; each
            line's ``body`` is expected to hold ``model`` and ``messages``.
        custom_llm_provider: Accepted for signature parity with the other
            batch helpers; not used by the token count itself.
        model_name: Fallback model for lines whose body names no model.

    Returns:
        Usage whose prompt and total token counts are the summed message
        tokens of all lines; completion tokens are always 0 because an
        input file contains no completions.
    """
    prompt_tokens: int = 0

    for _item in file_content_dictionary:
        # `or` fallbacks also cover keys that are present but null.
        body = _item.get("body") or {}
        messages = body.get("messages") or []
        if not messages:
            continue

        model = body.get("model") or model_name or ""
        prompt_tokens += token_counter(model=model, messages=messages)

    return Usage(
        total_tokens=prompt_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=0,
    )

def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """
    Build a ``Usage`` object from the ``usage`` entry of a response body.

    A missing or empty ``usage`` entry yields ``Usage()`` constructed with
    no arguments.
    """
    usage_fields = response_body.get("usage") or {}
    return Usage(**usage_fields)


def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
    """
    Get the response from the batch job output file
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    _response_body = _response.get("body", None) or {}
    return _response_body


def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
    """
    Check if the batch job response status == 200
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    return _response.get("status_code", None) == 200