# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import flashinfer
import pytest
import torch

from vllm.platforms import current_platform

# The TRT-LLM kernels exercised here require device capability 100
# (NVIDIA Blackwell, per the skip reason); skip the whole module otherwise.
if not current_platform.is_device_capability(100):
    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
                allow_module_level=True)

# Width of a single float32 element, in bytes.
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)

# Maximum query / KV lengths used when generating random sequence lengths
# (the last sequence in each batch is pinned to the maximum).
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096

# Parameter grids shared by the tests below.
BATCH_SIZES = [4, 12]
# Pairs of (num_query_heads, num_kv_heads).
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.bfloat16]
# None means the KV cache uses the same dtype as the query.
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
# Large enough to test overflow in index calculation.
NUM_BLOCKS = 32768
SOFT_CAPS = [None, 50.0]


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    """Compare TRT-LLM paged decode against FlashInfer's decode wrapper.

    A random paged KV cache is built (quantized to FP8 when
    ``kv_cache_dtype`` is the platform FP8 dtype), along with random block
    tables. ``BatchDecodeWithPagedKVCacheWrapper`` runs as the reference,
    then ``trtllm_batch_decode_with_kv_cache`` runs on the same inputs. Both
    outputs must agree within ``atol=rtol=1e-2``.
    """
    # None means "store the KV cache in the query dtype".
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    # Pin one sequence to the maximum length so the longest path is covered.
    kv_lens[-1] = MAX_KV_LEN
    max_kv_len = torch.max(kv_lens).item()
    num_seqs = len(kv_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    # One query token per sequence (decode).
    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype == current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale

    # FlashInfer paged-KV metadata derived from the same block tables the
    # TRTLLM kernel reads. Computed with tensor ops rather than a Python loop
    # over 0-d device tensors, which would sync once per element.
    assert torch.all(kv_lens > 0)
    num_blocks_per_seq = (kv_lens + block_size - 1) // block_size
    kv_indptr = torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(num_blocks_per_seq, dim=0, dtype=torch.int32),
    ])
    # Row-major boolean masking keeps, in order, the first
    # num_blocks_per_seq[i] entries of each row i of block_tables.
    used_blocks = (torch.arange(max_num_blocks_per_seq)
                   < num_blocks_per_seq.unsqueeze(1))
    kv_indices = block_tables[used_blocks]
    # Number of tokens in each sequence's last page, in [1, block_size].
    kv_last_page_lens = (kv_lens - 1) % block_size + 1

    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

    # TRTLLM Decode
    # NOTE(review): soft_cap is passed only to the baseline wrapper's plan();
    # the TRTLLM call below gets no soft-cap argument. Confirm whether the
    # soft_cap=50.0 case actually exercises soft-capping on the TRTLLM side.
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        # kv_lens is already an int32 tensor on the default device; no copy.
        seq_lens=kv_lens,
        max_seq_len=max_kv_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        out=output_trtllm,
    )

    # On failure, assert_close's message already reports the greatest
    # absolute and relative differences.
    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    """Compare TRT-LLM paged context attention against FlashInfer's prefill.

    Sequence ``i`` has ``q_lens[i]`` query tokens and a total KV length of
    ``kv_lens[i] + q_lens[i]``. ``BatchPrefillWithPagedKVCacheWrapper`` with
    ``causal=True`` runs as the reference, then
    ``trtllm_batch_context_with_kv_cache`` runs on the same inputs. Both
    outputs must agree within ``atol=rtol=1e-2``.

    The test is skipped when the KV cache dtype differs from the query dtype.
    """
    # None means "store the KV cache in the query dtype".
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
    if dtype != kv_cache_dtype:
        pytest.skip(f"Not supported dtype({dtype}) with "
                    f"kv_cache_dtype({kv_cache_dtype})")

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
    # Pin one sequence to the maximum query length.
    q_lens[-1] = MAX_Q_LEN
    max_q_len = torch.max(q_lens).item()
    q_indptr = torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
    ])

    kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    kv_lens[-1] = MAX_KV_LEN

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()
    num_seqs = len(seq_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    # All query tokens of all sequences, concatenated along dim 0.
    query = torch.randn(torch.sum(q_lens).item(),
                        num_query_heads,
                        head_size,
                        dtype=dtype)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype == current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale

    # FlashInfer paged-KV metadata derived from the same block tables the
    # TRTLLM kernel reads. Computed with tensor ops rather than a Python loop
    # over 0-d device tensors, which would sync once per element.
    assert torch.all(seq_lens > 0)
    num_blocks_per_seq = (seq_lens + block_size - 1) // block_size
    kv_indptr = torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(num_blocks_per_seq, dim=0, dtype=torch.int32),
    ])
    # Row-major boolean masking keeps, in order, the first
    # num_blocks_per_seq[i] entries of each row i of block_tables.
    used_blocks = (torch.arange(max_num_blocks_per_seq)
                   < num_blocks_per_seq.unsqueeze(1))
    kv_indices = block_tables[used_blocks]
    # Number of tokens in each sequence's last page, in [1, block_size].
    kv_last_page_lens = (seq_lens - 1) % block_size + 1

    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout)
    wrapper.plan(q_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 causal=True,
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

    # TRTLLM Prefill (context attention)
    # NOTE(review): cum_seq_lens_kv receives kv_indptr, which holds
    # cumulative *page* counts rather than cumulative token counts. Confirm
    # this matches what trtllm_batch_context_with_kv_cache expects.
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_q_len=max_q_len,
        max_kv_len=max_seq_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        batch_size=num_seqs,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        out=output_trtllm,
    )

    # On failure, assert_close's message already reports the greatest
    # absolute and relative differences.
    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2)
