# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile

import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear


@pytest.fixture(autouse=True)
def setup_environment():
    """Set up a single-process distributed environment around every test.

    QKVParallelLinear needs the distributed environment and model-parallel
    groups to be initialised. This fixture builds a throwaway vLLM config
    from ``EngineArgs``. With that config active, it calls
    ``init_distributed_environment`` and
    ``ensure_model_parallel_initialized(1, 1)``. The first call uses the
    positional arguments ``(1, 0)`` (presumably world_size=1, rank=0 --
    confirm against the vLLM signature), the gloo backend and a temporary
    file as the ``file://`` rendezvous. The test body runs while the config
    is still active, since the ``yield`` sits inside the ``with`` block.
    """
    # This is a fake config used only to initialise the distributed env.
    engine_args = EngineArgs(
        model="Qwen/Qwen2-1.5B-Instruct",
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        # mkstemp() returns an already-open OS-level descriptor together
        # with the path. Only the path is needed for the file:// init
        # method, so close the descriptor right away. Otherwise, this
        # autouse fixture would leak one fd per test.
        fd, temp_file = tempfile.mkstemp()
        os.close(fd)
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)
        yield


# Cache for the SPMD mesh. It stays None until _get_spmd_mesh() first builds it.
MESH = None


def _get_spmd_mesh():
    """Return the module-wide SPMD mesh, building it on the first call.

    On the first call this:

    1. Switches torch_xla into SPMD mode with ``xr.use_spmd()``. This is a
       global state change.
    2. Builds a ``(device_count, 1)`` mesh over device ids
       ``0 .. device_count - 1`` with axis names ``('x', 'y')``.
    3. Stores the result in ``MESH``.

    Every later call returns the stored mesh without touching torch_xla
    again.
    """
    global MESH
    if MESH is not None:
        return MESH

    xr.use_spmd()
    device_count = xr.global_runtime_device_count()
    MESH = xs.Mesh(np.array(range(device_count)), (device_count, 1),
                   ('x', 'y'))
    return MESH


@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` sets global state that cannot be reverted.
# NOTE(review): `_get_spmd_mesh()` below is evaluated when this decorator is
# evaluated, i.e. at module import / test collection. SPMD mode is therefore
# already enabled before any case runs, including the mesh=None cases. Also,
# with these stacked parametrize decorators the mesh=None and mesh cases
# appear to interleave per device rather than all non-SPMD cases running
# first. Confirm the intended ordering.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ['cpu', 'xla'])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
    """Check that XlaQKVParallelLinear matches the QKVParallelLinear it wraps.

    A bfloat16 QKVParallelLinear is built with hidden size 4096, 32 query
    heads, 8 KV heads and head size 128. It is given seeded random weights
    and, optionally, a random bias, and is then wrapped in
    XlaQKVParallelLinear. Both layers are moved to ``device`` and fed the
    same ``(10, 4096)`` input. Their outputs, copied to CPU, must be
    ``torch.allclose``.
    """
    torch.manual_seed(123)

    reference = QKVParallelLinear(
        hidden_size=4096,
        head_size=128,
        total_num_heads=32,
        total_num_kv_heads=8,
        bias=bias,
        params_dtype=torch.bfloat16,
        return_bias=False,
    )

    # Random draws happen in this exact order (weight, then bias, then
    # input) so the seeded values are reproducible.
    reference.weight.data = torch.rand_like(reference.weight.data) / 10
    if bias:
        reference.bias.data = torch.rand_like(reference.bias.data)

    # The wrapper is constructed from the reference layer before either
    # one is moved to the target device.
    candidate = XlaQKVParallelLinear(reference, mesh=mesh)

    reference = reference.to(device)
    candidate = candidate.to(device)
    inputs = (torch.rand(10, 4096, dtype=torch.bfloat16) / 10).to(device)

    # Run both forwards first; copy to CPU only when comparing.
    expected = reference(inputs)
    actual = candidate(inputs)
    assert torch.allclose(expected.cpu(), actual.cpu())
