#################################################################################################
#
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

#
#
# \brief Generates the CUTLASS kernel listing with kernel filtering
#

#

###############################################################################
# Example usage:
# generator.py --operations all --generator-target kernel_listing \
# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports
###############################################################################

import collections
import csv
import json
import math
import os

try:
  import builtins
  if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
    raise ImportError("Disabling attempt to import cutlass_library")
  from cutlass_library.library import *
except ImportError:
  from library import *

audit_csv_fields = [
  "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD",
  "Layout_A", "Layout_B", "Layout_C", "Layout_D", 
  "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D",  
  "1SM/2SM", 
  "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types",
  "Test Counts"
]

audit_csv_runtime_fields = [
  "KerneIndex", "KernelName", 
  "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K",
  "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K",
  "M", "N", "K", "L", "Alpha_val", "Beta_val",
  "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled"
]

def hash_cutlass_string(input_string):
  mma_cluster_shape_pattern = r"_\d+x\d+x\d+"         # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')

  # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
  output = re.sub(mma_cluster_shape_pattern, "", input_string)

  return output

def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
  # Define a dictionary mapping the detected types to runtime values
  datatype_map = {
    'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
    'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
    'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
    'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
    'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
    'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
    'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
    'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
    'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
    'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
    'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
    'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
  }

  # Regular expression to detect all the keys in datatype_map
  pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')

  # Replace detected patterns using the dictionary
  updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)

  return updated_kernel_name

# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
def get_kernel_features(operation, kernel_name,
              dynamic_datatype, runtime_input_datatype):
  numcta_inst = "2sm" if "2sm" in kernel_name else "1sm"
  math_inst = operation.tile_description.math_instruction

  if dynamic_datatype:
      dtype_name_A = runtime_input_datatype[0]
      dtype_name_B = runtime_input_datatype[1]
  else:
      dtype_name_A = DataTypeNames[operation.A.element]
      dtype_name_B = DataTypeNames[operation.B.element]

  layout_name_A = ShortLayoutTypeNames[operation.A.layout]
  layout_name_B = ShortLayoutTypeNames[operation.B.layout]
  layout_name_C = ShortLayoutTypeNames[operation.C.layout]
  layout_name_D = ShortLayoutTypeNames[operation.D.layout]

  scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void
  scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void)
  audit_vals = [
          "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM",
          kernel_name,
          dtype_name_A,
          dtype_name_B,
          DataTypeNames[operation.C.element],
          DataTypeNames[operation.tile_description.math_instruction.element_accumulator],
          DataTypeNames[operation.element_epilogue],
          DataTypeNames[operation.D.element],
          DataTypeNames[scale_factor_D_type],
          DataTypeNames[scale_factor_A_type],
          layout_name_A,
          layout_name_B,
          layout_name_C,
          layout_name_D,
          str(operation.A.alignment),
          str(operation.B.alignment),
          str(operation.C.alignment),
          str(operation.D.alignment),
          numcta_inst,
          "Y" if 'stream_k' in kernel_name else "N",
  ]
  return audit_vals

# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta.
def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster):
  math_inst = operation.tile_description.math_instruction
  audit_vals = [
          str(math_inst.instruction_shape[0]),
          str(math_inst.instruction_shape[1]),
          str(math_inst.instruction_shape[2]),
          str(operation.tile_description.threadblock_shape[0]),
          str(operation.tile_description.threadblock_shape[1]),
          str(operation.tile_description.threadblock_shape[2]),
          str(operation.tile_description.cluster_shape[0]),
          str(operation.tile_description.cluster_shape[1]),
          str(operation.tile_description.cluster_shape[2]),
          str(cluster_shape[0]),
          str(cluster_shape[1]),
          str(cluster_shape[2]),
          str(fallback_cluster_shape[0]),
          str(fallback_cluster_shape[1]),
          str(fallback_cluster_shape[2]),
          str(problem_shape[0]),
          str(problem_shape[1]),
          str(problem_shape[2]),
          str(problem_shape[3]),
          str(alpha),
          str(beta),
          "Y" if dynamic_datatype else "N",
          "Y" if dynamic_cluster else "N",
  ]
  return audit_vals


def _getSubOperationType(kernel):

  if kernel.operation_kind == OperationKind.Gemm:
      return GemmKindNames[kernel.gemm_kind]
  elif kernel.operation_kind == OperationKind.Conv2d:
    return "conv_" + ConvKindNames[kernel.conv_kind]
  elif kernel.operation_kind == OperationKind.Syrk:
    return "syrk_" + SyrkKindNames[kernel.syrk_kind]
  elif kernel.operation_kind == OperationKind.Trmm:
    return "trmm_" + TrmmKindNames[kernel.trmm_kind]
  elif kernel.operation_kind == OperationKind.Symm:
    return "symm_" + SymmKindNames[kernel.symm_kind]
  else:
    raise Exception("Unsupported kernel type")

def _get_inst_shape(math_instruction):
  return "".join(str(x) for x in math_instruction.instruction_shape)

def _is_simt_inst(math_instruction):
  return _get_inst_shape(math_instruction) in ["111","114"]

def _getInstType(input_precision, accumulate_precision, math_instruction):

  # inst_shape
  inst_shape = _get_inst_shape(math_instruction)

  # input precision
  if input_precision == "fp32" and inst_shape != "111":
    inp = "tf32"
  else:
    inp = input_precision

  # Handle SIMT op types first
  if _is_simt_inst(math_instruction):

    simt_input_precision_to_inst = {
      "fp32": "FFMA",
      "fp64": "DFMA",
      "fp16": "HFMA",
      "int8": "IDP4A",
    }
    inst = simt_input_precision_to_inst[input_precision]

  else: # Tensor op instructions

    if accumulate_precision == "cf64":
      fp64_acc_map = {
        MathOperation.multiply_add_complex_gaussian : "gz",
        MathOperation.multiply_add_complex          : "z",
      }
      acc = fp64_acc_map[math_instruction.math_operation]
    else:
      tensor_op_acc_map = {
        "fp32" : "s",
        "cf32" : "s",
        "fp16" : "h",
        "int32": "i",
        "fp64" : "d",
      }
      acc = tensor_op_acc_map[accumulate_precision]

    inst = "{}{}{}".format(acc, inst_shape, inp)

  return inst
# TODO: Computes FLOps/Bytes for GEMM - revisit for conv
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
  assert not (batch_count > 1 and num_groups > 1)

  # TODO: adjust for sparsity
  gmem_bytes = (
    (DataTypeSize[operation.A.element] * m // 8) * k +
    (DataTypeSize[operation.B.element] * n // 8) * k +
    (DataTypeSize[operation.C.element] * m // 8) * n
  )

  # TODO: complex-valued support
  flops = 2 * (m * n * k)

  if bool(beta):
    gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
    flops += 2 * m * n

  multiplier = max(batch_count, num_groups)
  gmem_bytes *= multiplier
  flops *= multiplier

  return flops / gmem_bytes

def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
                              ):
  profiler_reference_computing = "--verification-providers=device --providers=cutlass"
  # beta values for L0 and L1
  # TODO: randomize beta values for wider coverage
  beta_values = [0.5]

  is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])

  is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch

  if (mode == "functional_L0") and is_supported_arch:
    problem_waves = [0.5, 1.25, 2.5]

    #
    # Dense Gemm
    #

    sm100_mma_data_type_general = [
      'gemm_f16_f16_f16_f16_f16',
      'gemm_f16_f16_f16_void_f16',
      'gemm_f16_f16_f32_f16_f16',
      'tf32gemm_f32_f32_f32_f32_f32',
      'bf16gemm_f32_f32_f32_f32_f32',
    ]

    sm100_mma_data_type_runtime_dtype = [
      'gemm.*f4_f4_f32_f32_f32',
      'gemm.*f6_f6_f32_f32_f32',
      'gemm.*f8_f8_f32_f32_f32',
    ]

    sm100_mma_cluster_size = [
      '8x1x1',
      '4x4x1', '2x1x1',
      '0x0x1' # dynamic cluster
    ]

    # Restrict to two layouts to reduce L0 build and test time.
    sm100_mma_layouts = [ 
      'tnt', 
      'ntn' 
    ]

    # regex list must be in kernel procedural name order
    sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
    sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"

    sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
    sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"

    #
    # Block Scale Gemm
    #

    block_scaled_data_type = [
      # runtime datatypes
      'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
      'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
      'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
      'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
      'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
    ]

    block_scaled_cluster_size = [
      '4x4x1', '2x1x1',
      '0x0x1' # dynamic cluster
    ]

    block_scaled_layouts = ['tnt']
    # regex list must be in kernel procedural name order
    block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
    block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
    
    if arch in ["100a", "100f"]:
      kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
                      f"({sm100_mma_filter_regex_2sm})|" \
                      f"({sm100_mma_filter_regex_1sm_runtime})|" \
                      f"({sm100_mma_filter_regex_2sm_runtime})|" \
                      f"({block_scaled_filter_regex_1sm})|" \
                      f"({block_scaled_filter_regex_2sm})"
    elif arch in ["101a", "101f",
                  ]:
      kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
                      f"({sm100_mma_filter_regex_2sm})|" \
                      f"({sm100_mma_filter_regex_1sm_runtime})|" \
                      f"({sm100_mma_filter_regex_2sm_runtime})|" \
                      f"({block_scaled_filter_regex_1sm})|" \
                      f"({block_scaled_filter_regex_2sm})"
    elif arch in ["120a", "120f"]:

      # blockscaled sm120_mma kernels
      blockscaled_sm120_mma_kernel_cta_tiles = [
        [ '128x128' ]
      ]

      # Restrict to two layouts to reduce L0 build and test time.
      blockscaled_sm120_mma_layouts = [ 'tn' ]
      filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
      
      problem_waves = [0.5, 1.25, 2.5]

      kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
    else:
      error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
      raise Exception(error_message)

  elif mode == "functional_L1":
    sm100_mma_cluster_size = [
                    '0x0x1' # dynamic cluster
                     ]
    # Restrict to two layouts to reduce L1 build and test time.
    sm100_mma_layouts = ['tnt', 'ntn']
    sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
    sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
    block_scaled_data_type = [
      'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
      'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
      'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2',
      'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
      'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
    ]

    block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
    block_scaled_layouts = ['tnt']

    # regex list must be in kernel procedural name order
    block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
    block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
    filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
                          f"({sm100_mma_filter_regex_2sm})|" \
                          f"({block_scaled_filter_regex_1sm})|" \
                          f"({block_scaled_filter_regex_2sm})|" 
    # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
    sm120_mma_kernel_cta_tiles = [
      # h1688, s1688, i16832, i8816
      [ '256x128' ],
      # d884, c1688,
      [ '128x128' ],
      # c1688, z884
      [ '128x64' ],
      # gz884
      [ '64x64' ]
    ]

    # sm120 MMA instruction shapes, planar complex type excluded as they are not required
    sm120_mma_instruction_shapes = [
      [ 'h1688gemm_(?!planar_complex)',
        's1688gemm_f16',
        's1688gemm_bf16',
        's1688gemm_tf32',
        'i16832gemm',
        'i8816gemm' ],
      [ 'd884gemm', 'c1688tf32gemm' ] ,
      [ 'c1688gemm',
        'z884gemm'  ],
      [ 'gz884gemm']
    ]

    # It's not pretty, but not sure why different instructions support different tile sizes.
    filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*"
    filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*"
    filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*"
    filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*"

    filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})"

    problem_waves = [0.5, 1.25, 2.5]

    kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
  else:
    raise ValueError()

  outfile_name    = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")

  audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv")

  audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")

  kernel_filter_re = re.compile(kernel_filter)
  testcase_counter = 0
  kernels_emitted = 0
  kernels_total = 0

  perf_json_list = []
  kernel_name_set = set()

  testlist_csv_fields = ["testcase", "metadata"]
  testlist_csv_rows = []
  auditlist_csv_map = {}
  auditlist_csv_params_map = {}

  kernel_features = {}

  for cc in manifest.operations[OperationKind.Gemm].keys():
    for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items():
      assert(len(operation_l) == 1)
      kernels_total += 1
      if len(kernel_filter_re.findall(kernel_name)) == 0:
          continue
      # Only test f16 I/O void C kernels in void C kernel set
      # Exception: Use void C kernels for more accurate perf testing
      if '_void_' in kernel_name and  'perf_' not in mode:
        if 'f16_f16_f16_void_f16' not in kernel_name :
          continue

      kernels_emitted += 1
      kernel_name_set.add(kernel_name)
      hashed_kernel_name = hash_cutlass_string(kernel_name)
      operation = operation_l[0]

      dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0
                          or operation.tile_description.cluster_shape[1] == 0)

      dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name

      runtime_input_datatypes = [None]

      if dynamic_datatype:
        if "f4_f4" in kernel_name:
          runtime_input_datatypes = [['e2m1','e2m1']]
        elif "f4_f6" in kernel_name:
          runtime_input_datatypes = [['e2m1','e3m2']]
        elif "f4_f8" in kernel_name:
          runtime_input_datatypes = [['e2m1','e4m3']]

        elif "f6_f4" in kernel_name:
          runtime_input_datatypes = [['e3m2','e2m1']]
        elif "f6_f6" in kernel_name:
          runtime_input_datatypes = [['e3m2','e3m2']]
        elif "f6_f8" in kernel_name:
          runtime_input_datatypes = [['e3m2','e4m3']]

        elif "f8_f4" in kernel_name:
          runtime_input_datatypes = [['e4m3','e2m1']]
        elif "f8_f6" in kernel_name:
          runtime_input_datatypes = [['e4m3','e3m2']]
        elif "f8_f8" in kernel_name:
          runtime_input_datatypes = [
                                    # mask out those not covered in statically encoded test cases
                                    #  ['e5m2','e4m3'],
                                    #  ['e4m3','e5m2'],
                                      ['e4m3','e4m3']
                                    ]

        # block scaled kernels
        elif "ue8m0xf4_ue8m0xf4" in kernel_name:
          runtime_input_datatypes = [['e2m1','e2m1']]
        elif "ue4m3xf4_ue4m3xf4" in kernel_name:
          runtime_input_datatypes = [['e2m1','e2m1']]
        elif "ue8m0xf4_ue8m0xf6" in kernel_name:
          runtime_input_datatypes = [['e2m1','e2m3']]
        elif "ue8m0xf4_ue8m0xf8" in kernel_name:
          runtime_input_datatypes = [['e2m1','e4m3']]

        elif "ue8m0xf6_ue8m0xf4" in kernel_name:
          runtime_input_datatypes = [['e2m3','e2m1']]
        elif "ue8m0xf6_ue8m0xf6" in kernel_name:
          runtime_input_datatypes = [['e2m3','e2m3']]
        elif "ue8m0xf8_ue8m0xf4" in kernel_name:
          runtime_input_datatypes = [['e4m3','e2m1']]

        elif "ue8m0xf8_ue8m0xf4" in kernel_name:
          runtime_input_datatypes = [['e4m3','e2m1']]
        elif "ue8m0xf8_ue8m0xf6" in kernel_name:
          runtime_input_datatypes = [['e4m3','e2m3']]
        elif "ue8m0xf8_ue8m0xf8" in kernel_name:
          runtime_input_datatypes = [['e4m3','e4m3']]

      if dynamic_cluster:
        if mode == "functional_L0":
          runtime_cluster_shapes = [[1,1,1],          [2,1,1], [2,2,1],          [4,1,1],                   [4,4,1]]
        else:
          runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
        cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
      else:
        runtime_cluster_shapes = [operation.tile_description.cluster_shape]
        cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0])
        cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1])
        cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2])

      alignment_a = operation.A.alignment
      alignment_b = operation.B.alignment
      alignment_c = operation.C.alignment
      alignment_ab_max = max(alignment_a, alignment_b)

      layout3x = operation.layout_name_3x()
      data_types = operation.datatype_name_3x()

      ctas_per_mma_instruction = 1
      if '_2sm' in kernel_name:
        ctas_per_mma_instruction = 2
        valid_cluster_shapes = []

        # Remove any cluster shapes that have cluster_m that is not divisible by 2
        for cs in runtime_cluster_shapes:
          if cs[0] % 2 == 0:
            valid_cluster_shapes.append(cs)
        runtime_cluster_shapes = valid_cluster_shapes

      kernel_problem_waves = problem_waves
      if mode == "functional_L0" or mode == "functional_L1":
        # for functional testing, we want to perturb just a little from even shapes
        # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not
        # -16 ensures that we are TMA aligned even for FP8/Int8
        min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max
        max_k = (cta_tile_shape_k*8) - alignment_ab_max
        problem_shapes_k = [min_k, max_k]
        sm_count = 16
        swizzle_sizes = [0]
        # Larger k and less than half wave trigger streamk +separate reduction case to be generated
        if 'stream_k' in kernel_name:
          problem_shapes_k = [max_k, cta_tile_shape_k*32]
          kernel_problem_waves = [0.125, 1.25, 2.5]
      else:
        raise ValueError

      if "void" in kernel_name:
        beta_values = [0]

      alignment_shift_m = max(alignment_c, alignment_a)
      alignment_shift_n = max(alignment_c, alignment_b)

      is_first_line = True
      for index_waves, waves in enumerate(kernel_problem_waves):
        for index_k, k in enumerate(problem_shapes_k):
          for beta in beta_values:
            for cluster_shape in runtime_cluster_shapes:
              for runtime_input_datatype in runtime_input_datatypes:
                for swizzle_size in swizzle_sizes:
                  grid_size = waves * sm_count
                  cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
                  if cluster_shape_m >= cluster_shape_n:
                    grid_m = cluster_shape_m
                    grid_n = grid_size / grid_m
                    grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
                  else:
                    grid_n = cluster_shape_n
                    grid_m = grid_size / grid_n
                    grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)

                  verification_required = False
                  if mode == "functional_L0" or mode == "functional_L1":
                    if '_void_' not in kernel_name:
                      verification_required = True

                    m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
                    n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
                    k = int(k)

                    # For functional testing, we want to perturb just a little from even shapes.
                    # Only do this if the perturbation does not cause one of the dimensions of the
                    # problem size to go to zero. This can occur for blockscaling kernels for which
                    # the alignment requirements for A and B can be quite large (e.g., 256).
                    if m > alignment_shift_m:
                      m -= alignment_shift_m
                    if n > alignment_shift_n:
                      n -= alignment_shift_n

                    if '_n32t32_' in kernel_name:
                      continue
                  batch_count = 1
                  if mode == "functional_L0" or mode == "functional_L1" :
                    if index_waves == 0 and index_k == 0 :
                      batch_count = 3 if mode == "functional_L0" else 5
                  gemm_op = "gemm"

                  profiler_reference_computing_override = profiler_reference_computing
                  grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
                  num_groups = 1
                  if "bstensorop" in kernel_name:
                    profiler_reference_computing_override = "--mode=trace"
                  if grouped:
                    gemm_op = "grouped_gemm"
                    num_groups = 3 # small to limit test time in host block-scaled reference kernels
                    batch_count = 1
                  elif "bstensorop" in kernel_name:
                    gemm_op = "block_scaled_gemm"
                  elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
                    gemm_op = "blockwise_gemm"

                  problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]

                  assert m > 0 and n > 0 and k > 0

                  # Emit per-testcase metadata for perf testing usage, eventually in perf database
                  metadata_dict = {
                    "input_params": {
                      'problem_size_category' : problem_size_category,
                      'operation' : _getSubOperationType(operation),
                      'datatype' : data_types,
                      'layout' : layout3x,
                      'm' : m,
                      'n' : n,
                      'k' : k,
                      'beta' : beta,
                      'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
                    },
                    "runtime_params": {
                      'ctas_per_mma_instruction' : ctas_per_mma_instruction,
                      'tilesize_m' : cta_tile_shape_m,
                      'tilesize_n' : cta_tile_shape_n,
                      'tilesize_k' : cta_tile_shape_k,
                      'cluster_shape_m' : cluster_shape_m,
                      'cluster_shape_n' : cluster_shape_n,
                    }
                  }

                  cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
                  cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
                  cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k


                  if dynamic_datatype:
                    runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
                    metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
                    metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b

                  testcase_metadata = [
                    f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
                    f" --kernels={kernel_name}" +
                    f" --m={str(m)}" +
                    f" --n={str(n)}" +
                    f" --k={str(k)}" +
                    (f" --num_groups={str(num_groups)}" if grouped else "") +
                    f" --cluster_m={str(cluster_shape_m)}" +
                    f" --cluster_n={str(cluster_shape_n)}" +
                    f" --cluster_k={str(cluster_shape_k)}" +
                    f" --cluster_m_fallback={str(cluster_m_fallback)}" +
                    f" --cluster_n_fallback={str(cluster_n_fallback)}" +
                    f" --cluster_k_fallback={str(cluster_k_fallback)}" +
                    f" --beta={str(beta)}" +
                    ("" if grouped else f" --batch_count={str(batch_count)}") +
                    f" --swizzle_size={str(swizzle_size)}" +
                    f" --verification-required={str(verification_required).lower()}"
                  ] \

                  output_dynamic_datatype = dynamic_datatype
                  if output_dynamic_datatype:
                    testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
                                              f" --runtime_input_datatype_b={runtime_datatype_b}")

                  testcase_metadata.append(json.dumps(metadata_dict))
                  testlist_csv_rows.append(testcase_metadata)
                  testcase_counter += 1

                  alpha = 1.0

                  if dynamic_datatype:
                    hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)

                  # If kernel_name is new, initialize its feature set with defaults
                  if hashed_kernel_name not in kernel_features:
                    kernel_features[hashed_kernel_name] = {
                      "is_support_dynamic_cluster": False,
                      "is_support_dynamic_datatype": False,
                    }

                  # Update features for the hashed kernel name
                  kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
                  kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype

                  if hashed_kernel_name not in auditlist_csv_params_map:
                    auditlist_csv_params_map[hashed_kernel_name] = []

                  audit_row_params = get_kernel_params(
                    operation,
                    hashed_kernel_name,
                    (cluster_shape_m, cluster_shape_n, cluster_shape_k),
                    (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
                    (m, n, k, batch_count),
                    alpha, beta,
                    dynamic_datatype, dynamic_cluster
                  )

                  auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)

                  if hashed_kernel_name not in auditlist_csv_map:
                    audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
                    auditlist_csv_map[hashed_kernel_name] = audit_row

  with open(outfile_name, 'w') as testlist_csv:
    csv_writer = csv.writer(testlist_csv, delimiter=',')
    csv_writer.writerow(testlist_csv_fields)
    csv_writer.writerows(testlist_csv_rows)

  with open(audit_file_name, 'w') as auditlist_csv:
    csv_writer = csv.writer(auditlist_csv, delimiter=',')
    csv_writer.writerow(audit_csv_fields)
    for hashed_kernel_name, row in auditlist_csv_map.items():
      # Append the dynamic features as "Y" or "N"
      dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N"
      dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N"
      test_count = len(auditlist_csv_params_map[hashed_kernel_name])
      csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count])

  with open(audit_file_params_name, 'w') as auditlist_csv:
    csv_writer = csv.writer(auditlist_csv, delimiter=',')
    csv_writer.writerow(audit_csv_runtime_fields)
    for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1):
      for i, row in enumerate(rows):
        if i == 0:
          csv_writer.writerow([kernel_index, hashed_kernel_name] + row)
        else:
          csv_writer.writerow(["", ""] + row)

  print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.")

  # Generate a newline separated list of kernel filters
  assert(len(kernel_name_set) == kernels_emitted)
  output_filter_enabled = True
  if output_filter_enabled:
    kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
  with open(kernel_filter_outfile_name, "w") as file:
      kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set))
      for kernel_name in kernel_name_set:
          file.write(kernel_name + "\n")

  # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together.
  if mode == "functional_L0" or mode == "functional_L1":
    # Sort the .csv file
    outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
    with open(outfile_name) as file:
      data = file.readlines()
      data.sort()
    with open(outfile_name, 'w') as file:
      for i in range(len(data)):
        file.write(data[i])
    # Sort the kernel list
    kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
    with open(kernel_filter_outfile_name) as file:
      data = file.readlines()
      data.sort()
    with open(kernel_filter_outfile_name, 'w') as file:
      for i in range(len(data)):
        file.write(data[i])

