import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm

import pytest
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
import litellm



def test_split_embedding_by_shape_passes():
    """A flat list of 6 values with shape [2, 3] is split into 2 rows of 3.

    Failures are allowed to propagate (no try/except + pytest.fail wrapper) so
    pytest reports the real failing line and its assertion introspection.
    """
    data = [
        {
            "shape": [2, 3],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
        data[0]["data"], data[0]["shape"]
    )
    assert split_output_data == [[1, 2, 3], [4, 5, 6]]


def test_split_embedding_by_shape_fails_with_shape_value_error():
    """Splitting 6 values with the shape [2] must raise ValueError."""
    flat_values = [1, 2, 3, 4, 5, 6]
    bad_shape = [2]
    with pytest.raises(ValueError):
        TritonEmbeddingConfig.split_embedding_by_shape(flat_values, bad_shape)


@pytest.mark.parametrize("stream", [True, False])
def test_completion_triton_generate_api(stream):
    """Triton `/generate` completions build the expected request and parse the reply.

    ``HTTPHandler.post`` is patched, so no network call is made.

    - With ``stream=True``, the request must target ``/generate_stream``. The
      mocked SSE lines must come back as one delta per token, followed by a
      final chunk whose content is ``None``.
    - Otherwise, the mocked JSON body's ``text_output`` becomes the message
      content.

    Failures are allowed to propagate (no try/except + pytest.fail wrapper) so
    pytest shows the real failing assertion and traceback.
    """
    tokens = ["I", " am", " an", " AI", " assistant"]

    mock_response = MagicMock()
    mock_response.status_code = 200
    if stream:

        def mock_iter_lines():
            # Emulate an SSE stream: one `data: {...}` event per token,
            # each followed by a blank separator line.
            mock_output = "".join(
                'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"' + t + '"}\n\n'
                for t in tokens
            )
            yield from mock_output.split("\n")

        mock_response.iter_lines = mock_iter_lines
    else:

        def return_val():
            return {
                "text_output": "I am an AI assistant",
            }

        mock_response.json = return_val

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = litellm.completion(
            model="triton/llama-3-8b-instruct",
            messages=[{"role": "user", "content": "who are u?"}],
            max_tokens=10,
            timeout=5,
            api_base="http://localhost:8000/generate",
            stream=stream,
        )

        mock_post.assert_called_once()
        call_kwargs = mock_post.call_args.kwargs

        # Streaming requests are expected to go to the `_stream` endpoint variant.
        expected_url = (
            "http://localhost:8000/generate_stream"
            if stream
            else "http://localhost:8000/generate"
        )
        assert call_kwargs["url"] == expected_url

        # The request body is passed to `post` as a JSON string.
        request_data = json.loads(call_kwargs["data"])
        assert request_data["text_input"] == "who are u?"
        assert request_data["parameters"]["max_tokens"] == 10

        if stream:
            # Compare the whole sequence at once: too few, too many or wrong
            # chunks all show up in a single readable diff.
            received = [chunk.choices[0].delta.content for chunk in response]
            assert received == tokens + [None]
        else:
            assert response.choices[0].message.content == "I am an AI assistant"


def test_completion_triton_infer_api():
    """Triton `/infer` completions send a tensor-style request and parse its outputs.

    ``HTTPHandler.post`` is patched to return a canned `/infer` reply, so no
    network call is made. The request is checked for its ``inputs`` tensors
    (name/shape/datatype/data). The reply's ``text_output`` tensor must become
    the message content.

    ``litellm.set_verbose`` is enabled only for the duration of this test and
    restored afterwards, so the global flag does not leak into other tests.
    """
    prompt = "0004900005025 0004900005026 0004900005027"
    expected_text = "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"

    def return_val():
        return {
            "model_name": "basketgpt",
            "model_version": "2",
            "outputs": [
                {
                    "name": "text_output",
                    "datatype": "BYTES",
                    "shape": [1],
                    "data": [expected_text],
                },
                {
                    "name": "debug_probs",
                    "datatype": "FP32",
                    "shape": [0],
                    "data": [],
                },
                {
                    "name": "debug_tokens",
                    "datatype": "BYTES",
                    "shape": [0],
                    "data": [],
                },
            ],
        }

    mock_response = MagicMock()
    mock_response.json = return_val
    mock_response.status_code = 200

    original_verbose = litellm.set_verbose
    litellm.set_verbose = True
    try:
        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": prompt}],
                api_base="http://localhost:8000/infer",
            )

            mock_post.assert_called_once()
            call_kwargs = mock_post.call_args.kwargs
            assert call_kwargs["url"] == "http://localhost:8000/infer"

            # The request body is passed to `post` as a JSON string.
            request_data = json.loads(call_kwargs["data"])
            text_input = request_data["inputs"][0]
            second_input = request_data["inputs"][1]

            assert text_input["name"] == "text_input"
            assert text_input["shape"] == [1]
            assert text_input["datatype"] == "BYTES"
            assert text_input["data"] == [prompt]

            # NOTE(review): 20 is presumably litellm's default max-tokens value
            # when none is passed; confirm.
            assert second_input["shape"] == [1]
            assert second_input["datatype"] == "INT32"
            assert second_input["data"] == [20]

            assert response.choices[0].message.content == expected_text
            assert response.choices[0].finish_reason == "stop"
            assert response.choices[0].index == 0
            assert response.object == "chat.completion"
    finally:
        litellm.set_verbose = original_verbose


@pytest.mark.asyncio
async def test_triton_embeddings():
    """Async Triton embedding call returns the stub endpoint's fixed vector.

    NOTE(review): this hits a live external URL, so it depends on network
    access and on that stub staying available.

    ``litellm.set_verbose`` is restored afterwards so the global flag does not
    leak into other tests.
    """
    original_verbose = litellm.set_verbose
    litellm.set_verbose = True
    try:
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )

        # The stubbed endpoint is set up to return this fixed vector.
        assert response.data[0]["embedding"] == [0.1, 0.2]
    finally:
        litellm.set_verbose = original_verbose



def test_triton_generate_raw_request():
    """The raw Triton `/generate` request body contains no `bad_words`/`stop_words`.

    The body is serialized once and searched as text, so the words are caught
    anywhere in it.
    """
    from litellm.types.utils import CallTypes
    from litellm.utils import return_raw_request

    kwargs = {
        "model": "triton/llama-3-8b-instruct",
        "messages": [{"role": "user", "content": "who are u?"}],
        "api_base": "http://localhost:8000/generate",
    }
    raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs)
    assert raw_request is not None

    serialized_body = json.dumps(raw_request["raw_request_body"])
    assert "bad_words" not in serialized_body
    assert "stop_words" not in serialized_body

