# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
import subprocess

import jinja2

# Preamble written at the top of every generated kernel_*.cu file. It opens
# MARLIN_NAMESPACE_NAME; generate_new_kernels appends the closing "}".
FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

# Jinja2 template for a single explicit instantiation of the Marlin kernel.
# Each placeholder is filled in per configuration by generate_new_kernels.
# NOTE(review): MARLIN_KERNEL_PARAMS is presumably a macro (from kernel.h)
# expanding to the kernel's parameter list -- confirm.
TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{s_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# Weight types to instantiate kernels for; one output file is written per
# (scalar type, dtype) pair.
# The int8-with-zero-point case (vllm::kU8) is also supported, but it is not
# instantiated here in order to reduce wheel size.
SCALAR_TYPES = [
    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
    "vllm::kFE2M1f"
]
# (thread_k, thread_n, threads): the first two entries are divided by 16 to
# obtain thread_k_blocks / thread_n_blocks; the last is the thread count.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

# Values for thread_m_blocks. 0.5 selects the m_block_size_8 variant and is
# rendered as thread_m_blocks=1.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
# Activation dtypes: "fp16" -> half, "bf16" -> nv_bfloat16.
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    """Delete previously generated ``kernel_*.cu`` files next to this script.

    Files that vanish between listing and deletion are ignored, mirroring the
    behaviour of ``rm -f``. Other OS errors (e.g. permissions) propagate so a
    stale kernel that could not be removed is not silently left in place.
    """
    # abspath: a bare "generate.py" would otherwise give dirname "" and a
    # pattern rooted at "/".
    kernel_dir = os.path.dirname(os.path.abspath(__file__))
    # Escape the directory so glob metacharacters in the path ("[", "*", "?")
    # are matched literally; only the file-name part is a pattern.
    pattern = os.path.join(glob.escape(kernel_dir), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass


def _is_generated(scalar_type, group_blocks, m_blocks, thread_config):
    """Return True if the (type, group, m-blocks, thread config) combo is
    instantiated; all other combinations are skipped to limit binary size."""
    # act order case (group_blocks == 0) only supports gptq-int4 and gptq-int8
    if group_blocks == 0 and scalar_type not in ("vllm::kU4B8",
                                                 "vllm::kU8B128"):
        return False
    if thread_config[2] == 256:
        # 256-thread configs: small batches (m_blocks <= 1, which includes the
        # 0.5 / m_block_size_8 case) only need (128, 128, 256); larger
        # batches (m_blocks > 1) only need (64, 256, 256).
        wanted_thread_k = 128 if m_blocks <= 1 else 64
        if thread_config[0] != wanted_thread_k:
            return False
    # fp8 only supports channelwise quantization and group_size == 128
    if scalar_type == "vllm::kFE4M3fn" and group_blocks not in (-1, 8):
        return False
    # nvfp4 only supports group_size == 16 (group_blocks == 1);
    # mxfp4 only supports group_size == 32 (group_blocks == 2)
    if scalar_type == "vllm::kFE2M1f" and group_blocks not in (1, 2):
        return False
    # other quantization methods don't support group_size == 16
    if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
        return False
    return True


def _scale_type(scalar_type, group_blocks, dtype):
    """Return the scale-type name for a kernel, or None to skip it.

    Raises:
        ValueError: if ``dtype`` is not "fp16" or "bf16" (previously this left
            ``s_type`` unbound or stale from an earlier iteration).
    """
    if scalar_type == "vllm::kFE2M1f":
        if group_blocks == 1:
            # nvfp4 uses fp8 (e4m3) scales
            return "vllm::kFE4M3fn"
        if group_blocks == 2:
            # mxfp4 uses e8m0 scales; e8m0 cannot be safely dequantized to
            # fp16, so the fp16 variant is not generated.
            return None if dtype == "fp16" else "vllm::kFE8M0fnu"
    if dtype == "fp16":
        return "vllm::kFloat16"
    if dtype == "bf16":
        return "vllm::kBFloat16"
    raise ValueError(f"unsupported dtype: {dtype!r}")


def generate_new_kernels():
    """Write one ``kernel_{dtype}_{type}.cu`` per (scalar type, dtype) pair.

    Each file, placed next to this script, contains FILE_HEAD followed by an
    explicit Marlin instantiation for every configuration accepted by
    ``_is_generated`` / ``_scale_type``, and the closing namespace brace.
    """
    # Compile the Jinja2 template once; it is reused for every instantiation.
    template = jinja2.Template(TEMPLATE)
    out_dir = os.path.dirname(__file__)

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
        all_template_str_list = []

        for group_blocks, m_blocks, thread_config in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
            if not _is_generated(scalar_type, group_blocks, m_blocks,
                                 thread_config):
                continue
            s_type = _scale_type(scalar_type, group_blocks, dtype)
            if s_type is None:
                continue

            thread_k, thread_n, threads = thread_config
            all_template_str_list.append(
                template.render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    s_type_id=s_type + ".id()",
                    threads=threads,
                    # 0.5 means "m_block_size_8 with a single m block"
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=thread_n // 16,
                    thread_k_blocks=thread_k // 16,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=False,
                ))

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # "vllm::kU4B8" -> "ku4b8"; rpartition keeps names without a prefix.
        type_name = scalar_type.rpartition("::")[2].lower()
        filename = f"kernel_{dtype}_{type_name}.cu"

        with open(os.path.join(out_dir, filename), "w",
                  encoding="utf-8") as f:
            f.write(file_content)


def main():
    """Regenerate every Marlin kernel instantiation file from scratch."""
    # Clear old outputs first so files for scalar types or dtypes that are no
    # longer listed do not linger next to the freshly generated ones.
    remove_old_kernels()
    generate_new_kernels()


if __name__ == "__main__":
    main()
