# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.

Run `pytest tests/quantization/test_tpu_int8.py`.
"""
import pytest

from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
    TPUInt8LinearMethod)
from vllm.platforms import current_platform

from ...models.registry import HF_EXAMPLE_MODELS

# Hugging Face model ids exercised by the TPU int8 quantization test.
MODELS = [
    "Qwen/Qwen2.5-0.5B-Instruct",
]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    """Load ``model`` with a ``tpu_int8`` quantization override and check it.

    Asserts that every ``LinearBase`` layer uses ``TPUInt8LinearMethod`` and
    that its ``quantize_activation`` flag is set exactly when the override
    requests the ``'dynamic'`` activation scheme. Then checks that greedy
    generation still contains the expected continuation for each prompt.

    Args:
        vllm_runner: Fixture that constructs the vLLM engine wrapper.
        model: Hugging Face model id to load.
        dtype: Model dtype passed to the runner.
        max_tokens: Number of tokens to generate per prompt.
        hf_overrides: HF config overrides that inject the quantization config.
        monkeypatch: pytest fixture used to set environment variables.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get('quantization_config',
                                         {}).get('activation_scheme')
    quantize_activation = activation_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        "or, being injured, not kill, except in",
        "without the heart, one can only see wrongly.",
        "but in rising every time we fall. - Nelson"
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            num_linear = 0
            for name, module in model.named_modules():
                if not isinstance(module, LinearBase):
                    continue
                num_linear += 1
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod), (
                    f"{name}: expected TPUInt8LinearMethod, got "
                    f"{type(quant_method).__name__}")
                actual = quant_method.quantize_activation
                assert actual == quantize_activation, (
                    f"{name}: quantize_activation={actual}, "
                    f"expected {quantize_activation}")
            # Without this, a model exposing no LinearBase layers would
            # pass the checks above vacuously.
            assert num_linear > 0, "no LinearBase layers found to check"

        vllm.apply_model(check_model)
        outputs = vllm.generate_greedy(prompts, max_tokens)
        # zip() below silently truncates; make a missing output a failure.
        assert len(outputs) == len(answers), (
            f"expected {len(answers)} outputs, got {len(outputs)}")
        for (_, output), answer in zip(outputs, answers):
            assert answer in output, (
                f"expected {answer!r} in generated output {output!r}")
