# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file tests the AITER ops. It checks that they
# 1. are correctly registered as custom ops,
# 2. have a fake implementation that is consistent with
#    the real implementation, and
# 3. can be used with torch.compile.
# This file is skipped unless the platform is ROCm
# and AITER is installed.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
from vllm.platforms import current_platform

# Detect whether the optional ``aiter`` package is installed.
aiter_available = importlib.util.find_spec("aiter") is not None

# Module-wide marker: every test in this file is skipped when the platform
# is not ROCm or when the ``aiter`` package cannot be found.
pytestmark = pytest.mark.skipif(
    not current_platform.is_rocm() or not aiter_available,
    reason="AITER ops are only available on ROCm with aiter package installed")


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Verify ``rocm_aiter_biased_grouped_topk`` is registered and callable."""
    op_name = 'rocm_aiter_biased_grouped_topk'
    vllm_ops = torch.ops.vllm

    # The op must be resolvable by name on the ``torch.ops.vllm`` namespace.
    assert hasattr(vllm_ops, op_name)

    # The resolved object must be something that can be invoked.
    registered_op = getattr(vllm_ops, op_name)
    assert callable(registered_op)


def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Verify ``rocm_aiter_grouped_topk`` is registered and callable."""
    op_name = 'rocm_aiter_grouped_topk'
    vllm_ops = torch.ops.vllm

    # The op must be resolvable by name on the ``torch.ops.vllm`` namespace.
    assert hasattr(vllm_ops, op_name)

    # The resolved object must be something that can be invoked.
    registered_op = getattr(vllm_ops, op_name)
    assert callable(registered_op)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Check ``rocm_aiter_biased_grouped_topk`` under ``torch.compile``.

    The test has two stages:

    1. ``torch.library.opcheck`` validates the op's fake-tensor
       implementation.
    2. The op is run once eagerly and once through an Inductor-compiled
       wrapper on the same inputs. Results are read back from the output
       buffers passed to the op, and both runs must select the same expert
       IDs with matching weights.
    """
    # Problem sizes and routing parameters.
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")
    e_score_correction_bias = torch.randn((expert, ),
                                          dtype=torch.bfloat16,
                                          device="cuda")
    device = gating_output.device

    def make_output_buffers():
        """Allocate a fresh (topk_weights, topk_ids) pair for one op call."""
        weights = torch.empty((token, topk),
                              dtype=torch.float32,
                              device=device)
        ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        return weights, ids

    def sort_by_expert_id(weights, ids):
        """Reorder each row by ascending expert ID, keeping weights paired."""
        sorted_ids, order = torch.sort(ids)
        return torch.gather(weights, 1, order), sorted_ids

    # Define a function that uses the op
    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                               topk_weights, topk_ids):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output, e_score_correction_bias, topk_weights, topk_ids,
            num_expert_group, topk_group, renormalize, scale_factor)

    # Verify the op's fake implementation
    topk_weights, topk_ids = make_output_buffers()
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor
        },
        # A real one-element tuple; without the trailing comma the
        # parentheses would leave a bare string.
        test_utils=("test_faketensor", ))

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(biased_grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    # Each run gets its own output buffers so the results cannot alias.
    topk_weights_original, topk_ids_original = make_output_buffers()
    topk_weights_compiled, topk_ids_compiled = make_output_buffers()

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                           topk_weights_original, topk_ids_original)
    compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
                topk_ids_compiled)

    # Sort the results for comparison since the order might not be deterministic
    topk_weights_original, topk_ids_original = sort_by_expert_id(
        topk_weights_original, topk_ids_original)
    topk_weights_compiled, topk_ids_compiled = sort_by_expert_id(
        topk_weights_compiled, topk_ids_compiled)

    # Expert IDs are integers and must match exactly. They are checked first
    # because comparing weights is only meaningful for identical experts.
    assert torch.equal(topk_ids_original, topk_ids_compiled)
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)


def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Check ``rocm_aiter_grouped_topk`` under ``torch.compile``.

    The test has two stages:

    1. ``torch.library.opcheck`` validates the op's fake-tensor
       implementation.
    2. The op is run once eagerly and once through an Inductor-compiled
       wrapper on the same inputs. Results are read back from the output
       buffers passed to the op, and both runs must select the same expert
       IDs with matching weights.
    """
    # Problem sizes and routing parameters.
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")
    device = gating_output.device

    def make_output_buffers():
        """Allocate a fresh (topk_weights, topk_ids) pair for one op call."""
        weights = torch.empty((token, topk),
                              dtype=torch.float32,
                              device=device)
        ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        return weights, ids

    def sort_by_expert_id(weights, ids):
        """Reorder each row by ascending expert ID, keeping weights paired."""
        sorted_ids, order = torch.sort(ids)
        return torch.gather(weights, 1, order), sorted_ids

    # Define a function that uses the op
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output, topk_weights, topk_ids, num_expert_group,
            topk_group, renormalize, scoring_func, scale_factor)

    # Verify the op's fake implementation
    topk_weights, topk_ids = make_output_buffers()
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_grouped_topk,
        (gating_output, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "scoring_func": scoring_func,
            "routed_scaling_factor": scale_factor
        },
        # A real one-element tuple; without the trailing comma the
        # parentheses would leave a bare string.
        test_utils=("test_faketensor", ))

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    # Each run gets its own output buffers so the results cannot alias.
    topk_weights_original, topk_ids_original = make_output_buffers()
    topk_weights_compiled, topk_ids_compiled = make_output_buffers()

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
                    scoring_func)
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
                scoring_func)

    # Sort the results for comparison since the order might not be deterministic
    topk_weights_original, topk_ids_original = sort_by_expert_id(
        topk_weights_original, topk_ids_original)
    topk_weights_compiled, topk_ids_compiled = sort_by_expert_id(
        topk_weights_compiled, topk_ids_compiled)

    # Expert IDs are integers and must match exactly. They are checked first
    # because comparing weights is only meaningful for identical experts.
    assert torch.equal(topk_ids_original, topk_ids_compiled)
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)
