# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import random
from typing import Callable

import openai
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    """Launch a ``RemoteOpenAIServer`` for ``MODEL_NAME``, shared module-wide.

    The server is entered as a context manager and yielded from inside the
    ``with`` block, so it is exited once the fixture is torn down.
    """
    # Each line below is one CLI flag together with its value (if any).
    launch_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "bfloat16",
        "--max-model-len", "8192",
        "--enforce-eager",
        "--max-num-seqs", "128",
        # NOTE(review): presumably avoids loading real checkpoint weights
        # -- confirm against the server's CLI documentation.
        "--load-format", "dummy",
    ]

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture.

    The client is used as an async context manager, so it is exited after
    the consuming test has finished.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (
            lambda x: x.completions.create,
            {"prompt": " ".join(['A'] * 10_000)},
        ),
        (
            lambda x: x.chat.completions.create,
            {
                "messages": [{
                    "role": "user",
                    "content": " ".join(['A'] * 10_000)
                }]
            },
        ),
    ],
    ids=["completion", "chat"],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    """Mix truncated and non-truncated requests and check for no HTTP 500.

    Ten requests are sent concurrently against either the completions or the
    chat-completions endpoint (selected by ``create_func_gen``). Half of them
    carry ``truncate_prompt_tokens=1000`` and half carry ``None``, in shuffled
    order. Any non-500 outcome is accepted, including other error statuses.
    """
    send_request = create_func_gen(client)

    total_requests = 10
    truncated_count = total_requests // 2
    truncation_settings = ([1000] * truncated_count +
                           [None] * (total_requests - truncated_count))
    # Interleave both kinds of request so they hit the server mixed together.
    random.shuffle(truncation_settings)

    request_kwargs = [{
        "model": MODEL_NAME,
        **content_body,
        "max_tokens": 10,
        "extra_body": {
            'truncate_prompt_tokens': limit
        },
    } for limit in truncation_settings]

    async def status_of(kwargs: dict) -> int:
        """Issue one request; return 200 on success or the error status."""
        try:
            await send_request(**kwargs)
        except openai.APIStatusError as err:
            return err.status_code
        return 200

    status_codes = await asyncio.gather(
        *(status_of(kwargs) for kwargs in request_kwargs))
    assert 500 not in status_codes
