# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm import SamplingParams
from vllm.model_executor.utils import set_random_seed

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner, monkeypatch):
    # This file relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with vllm_runner(MODEL, dtype="half") as vllm_model:
        yield vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=3.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    llm = vllm_model.llm

    for prompt in example_prompts:
        for params in (
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
        ):
            llm._add_request(prompt, params=params)

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    for i in range(0, len(example_prompts), 6):
        outputs = all_outputs[i:i + 6]

        # verify all non-seeded requests differ
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
                2,
        ):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

        # verify generations within the same parallel sampling group differ
        for output in outputs:
            for sub_output_a, sub_output_b in combinations(output, 2):
                assert sub_output_a != sub_output_b
