# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from tests.conftest import VllmRunner
from tests.core.utils import create_dummy_prompt
from vllm.engine.llm_engine import LLMEngine
from vllm.sequence import SequenceGroup

MODEL = "JackFram/llama-160m"


def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
    """Enqueue ``seq_group`` on the engine's first scheduler.

    The group is handed to ``engine.scheduler[0].add_seq_group`` directly,
    rather than going through an engine-level request API.
    """
    engine.scheduler[0].add_seq_group(seq_group)


@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_num_computed_tokens_update(enable_chunked_prefill: bool,
                                    enforce_eager: bool):
    """Check ``num_computed_tokens`` bookkeeping over a request's lifetime.

    For several output lengths, a dummy prompt is added directly to the
    scheduler and the engine is stepped until the sequence finishes. The
    test asserts that the sequence's computed-token count:

    * is 0 before any engine step,
    * equals ``prompt_len + num_prompt_steps - 1`` after the prompt step(s),
    * grows by exactly one per decode step, and
    * equals ``prompt_len + num_output_tokens - 1`` once finished (i.e. the
      last sampled token is not counted as computed).
    """
    # Enter the runner as a context manager so its teardown (``__exit__``)
    # runs when the test ends -- including when an assertion fails --
    # instead of leaving the engine alive until garbage collection.
    # NOTE(review): relies on VllmRunner's __exit__ releasing the engine and
    # its GPU memory; see tests/conftest.py.
    with VllmRunner(model_name=MODEL,
                    gpu_memory_utilization=0.7,
                    enable_chunked_prefill=enable_chunked_prefill,
                    enforce_eager=enforce_eager) as runner:
        engine: LLMEngine = runner.llm.llm_engine

        num_prompt_steps = 1

        num_output_tokens_list = [4, 8, 12, 15, 16, 17]

        # Create sequence and add to engine
        prompt_len = 10

        for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
            seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
                                                 prompt_length=prompt_len,
                                                 min_tokens=num_output_tokens,
                                                 max_tokens=num_output_tokens)
            add_seq_group_to_engine(engine, seq_group)

            # Nothing has been computed before the first engine step.
            assert seq.data.get_num_computed_tokens() == 0

            for _ in range(num_prompt_steps):
                # prompt steps
                engine.step()

            # Only check intermediate counts if the request still has
            # decode work left after the prompt step(s).
            if not seq.is_finished():
                prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
                # Test correctness of num_computed_tokens after prompt steps
                assert (prompt_num_computed_tokens ==
                        prompt_len + num_prompt_steps - 1)

                decode_step_counter = 0
                while not seq.is_finished():
                    # Each decode step adds exactly one computed token.
                    assert (seq.data.get_num_computed_tokens() ==
                            prompt_num_computed_tokens + decode_step_counter)
                    engine.step()
                    decode_step_counter += 1

            # Test correctness of num_computed_tokens after the sequence
            # finishes.
            assert (seq.data.get_num_computed_tokens() ==
                    prompt_len + num_output_tokens - 1)
