# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm

DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
VEC_HIDDEN_SIZES = range(1024, 1030)
# Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [
    *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
    *[(83, i) for i in [1, 1033, 2048, 5120]],
    *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
    *[(4096, i) for i in [1, 64, 5137]],
]

ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

EPS = 1e-6

## Helpers


def as_float32_tensor(
    x: Union[float, torch.Tensor],
    device: Union[str, torch.device] = 'cuda',
) -> torch.Tensor:
    """Convert a Python scalar or tensor to a float32 tensor on ``device``.

    Args:
        x: scalar or tensor to convert.
        device: target device; defaults to ``'cuda'`` (the current CUDA
            device) to match the previous hard-coded behavior.

    Returns:
        A float32 tensor on ``device``. ``torch.as_tensor`` reuses ``x``'s
        storage when no dtype/device conversion is needed.
    """
    return torch.as_tensor(x, dtype=torch.float32, device=device)


def ref_rms_norm(rms_norm_layer: RMSNorm,
                 x: torch.Tensor,
                 residual: Optional[torch.Tensor]) \
        -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the layer's ``forward_native`` path as the RMSNorm reference.

    Args:
        rms_norm_layer: layer whose ``forward_native`` method is called.
        x: input activations.
        residual: optional residual. When given, a clone of it is passed to
            ``forward_native`` together with ``x``.

    Returns:
        ``(out, residual)``: the normalized output and, when a residual was
        supplied, the residual returned by ``forward_native``; otherwise
        ``None``.
    """
    if residual is None:
        normed = rms_norm_layer.forward_native(x)
        return normed, None

    # Pass a copy so nothing forward_native does to the residual can reach
    # the caller's tensor.
    normed, updated_residual = rms_norm_layer.forward_native(
        x, residual.clone())
    return normed, updated_residual


def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Unfused reference: RMSNorm followed by a standalone quantization op.

    Args:
        rms_norm_layer: layer used for the reference normalization.
        x: input activations.
        quant_dtype: ``torch.float8_e4m3fn`` or ``torch.int8``.
        residual: optional residual forwarded to ``ref_rms_norm``.
        scale_ub: optional upper bound for the scales; only allowed with
            fp8 quantization.

    Returns:
        ``(quantized, scales, residual)``, where ``residual`` is whatever
        ``ref_rms_norm`` returned.

    Raises:
        AssertionError: if ``scale_ub`` is given for a non-fp8 dtype, or if
            ``quant_dtype`` is neither fp8 (e4m3fn) nor int8.
    """
    assert scale_ub is None or quant_dtype == torch.float8_e4m3fn

    normed, residual = ref_rms_norm(rms_norm_layer, x, residual)

    if quant_dtype == torch.float8_e4m3fn:
        # Dynamic fp8 quant with one scale per token, optionally bounded.
        quantized, scales = ops.scaled_fp8_quant(
            normed, scale_ub=scale_ub, use_per_token_if_dynamic=True)
        return quantized, scales, residual

    assert quant_dtype == torch.int8
    # Called without an explicit scale; presumably computes scales
    # dynamically (per token) -- confirm against ops.scaled_int8_quant.
    quantized, scales = ops.scaled_int8_quant(normed)
    return quantized, scales, residual


def ref_impl(rms_norm_layer: RMSNorm,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Reference side of the comparison (counterpart of ``ops_impl``).

    Delegates to ``ref_dynamic_per_token_quant`` and returns its
    ``(quantized, scales, residual)`` result unchanged.
    """
    return ref_dynamic_per_token_quant(rms_norm_layer=rms_norm_layer,
                                       x=x,
                                       quant_dtype=quant_dtype,
                                       residual=residual,
                                       scale_ub=scale_ub)


def ops_dynamic_per_token_quant(weight: torch.Tensor,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Run the fused ``ops.rms_norm_dynamic_per_token_quant`` op.

    Args:
        weight: RMSNorm weight.
        x: input activations.
        quant_dtype: target quantized dtype.
        residual: optional residual. It is cloned before the call, so the
            caller's tensor is never handed to the op.
        scale_ub: optional upper bound for the scales.

    Returns:
        ``(out, scales, residual_clone)``. The clone is what gets compared
        with the reference residual, presumably because the op updates it in
        place -- confirm against the op's schema.
    """
    residual_copy = None if residual is None else residual.clone()
    out, scales = ops.rms_norm_dynamic_per_token_quant(
        x, weight, EPS, quant_dtype, scale_ub, residual_copy)
    return out, scales, residual_copy


def ops_impl(weight: torch.Tensor,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Fused-op side of the comparison (counterpart of ``ref_impl``).

    Delegates to ``ops_dynamic_per_token_quant`` and returns its
    ``(quantized, scales, residual)`` result unchanged.
    """
    return ops_dynamic_per_token_quant(weight=weight,
                                       x=x,
                                       quant_dtype=quant_dtype,
                                       residual=residual,
                                       scale_ub=scale_ub)


@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    scale_ub: bool,
    dtype: torch.dtype,
    quant_dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare the fused RMSNorm + dynamic per-token quant op to a reference.

    The reference is ``ref_impl`` (``RMSNorm.forward_native`` followed by a
    standalone quant op); the op under test is ``ops_impl``. The quantized
    outputs, the scales and, when ``add_residual`` is set, the residuals are
    compared. The raw ``_C`` op is also run through ``opcheck``.

    ``scale_ub`` is a boolean switch. When True, a scale upper bound derived
    from the reference RMSNorm output is passed to both implementations.
    When False, no bound (``None``) is passed.
    """
    # A scale upper bound is only supported with fp8 quantization
    # (ref_dynamic_per_token_quant asserts this). ``scale_ub`` is a bool, so
    # test its truthiness: an ``is not None`` check is always true, which
    # would silently skip every int8 case and never exercise the unbounded
    # fp8 path.
    if scale_ub and quant_dtype != torch.float8_e4m3fn:
        pytest.skip("scale_ub is only supported for fp8 quantization")

    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)

    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs
    scale = 1 / hidden_size
    x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
    residual = torch.randn_like(x) * scale if add_residual else None

    scale_ub_tensor: Optional[torch.Tensor] = None
    if scale_ub:
        rms_x, _ = ref_rms_norm(layer, x, residual)
        # Place the bound on the inputs' device. A bare 'cuda' means the
        # *current* CUDA device, which differs from x.device when the
        # parametrized device is e.g. cuda:1.
        scale_ub_tensor = torch.mean(rms_x).to(dtype=torch.float32,
                                               device=x.device)

    ref_out, ref_scales, ref_residual = \
        ref_impl(layer, x, quant_dtype, residual, scale_ub_tensor)
    ops_out, ops_scales, ops_residual = \
        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub_tensor)

    assert ref_out.dtype == quant_dtype
    assert ops_out.dtype == quant_dtype
    assert torch.allclose(ref_scales, ops_scales)
    if quant_dtype == torch.int8:
        # big atol to account for round-off errors.
        assert torch.allclose(ref_out, ops_out, atol=1)
    else:
        assert torch.allclose(ref_out.to(dtype=torch.float32),
                              ops_out.to(dtype=torch.float32))
    if add_residual:
        assert torch.allclose(ref_residual, ops_residual)

    output = torch.empty_like(x, dtype=quant_dtype)
    scales = torch.empty((x.numel() // x.shape[-1], 1),
                         device=x.device,
                         dtype=torch.float32)

    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
            (output, x, layer.weight, scales, EPS, scale_ub_tensor, residual))
