#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import collections
import os
import re
import sys
import zlib
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon

# Command-line interface: one YAML spec in, one or more test files out
# (micro-kernels are sharded across the `-o` files), plus an optional single
# benchmark file.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True, help="Spec (YAML) file")
parser.add_argument(
    "-o",
    "--output-test",
    action="append",
    metavar="FILE",
    required=True,
    help="Test output (C++ source) file(s)")
parser.add_argument(
    "-b",
    "--output-bench",
    metavar="FILE",
    required=False,
    help="Benchmark output (C++ source) file")
# NOTE(review): `defines` is not read anywhere in this file; kept in case other
# tooling inspects the parsed options.
parser.set_defaults(defines=[])

def split_ukernel_name(name):
  """Extracts tiling, packing and target information from a ukernel name.

  The name is split at the first "__" into a common part and a target part.
  The last "_"-separated field of the common part is the tile spec, parsed as
  `MRxNR[v...][cKR][sSR]`: the "s" suffix gives SR, the "c" suffix gives KR
  (both default to 1), and a "v" marks a vector tile (anything after it is
  ignored). An `mstepN` token in the target part divides MR to give the packed
  MR. The third-from-last field of the common part is reported as the
  requantization scheme when it is one of fp32/rndnu/rndnu16.

  Args:
    name: Full micro-kernel name.

  Returns:
    Tuple `(mr, nr, kr, sr, mr_packed, vector_tile, requantization, arch, isa,
    assembly)`, where the last three come from
    `xnncommon.parse_target_name`.
  """
  common_name, target_name = name.split("__", 1)
  common_fields = common_name.split("_")

  # Peel the optional suffixes off the tile spec, outermost ("s") first.
  tile, has_sr, sr_text = common_fields[-1].partition("s")
  sr = int(sr_text) if has_sr else 1
  tile, has_kr, kr_text = tile.partition("c")
  kr = int(kr_text) if has_kr else 1
  tile, has_vector_marker, _ = tile.partition("v")
  vector_tile = bool(has_vector_marker)

  mr, nr = (int(dim) for dim in tile.split("x"))
  arch, isa, assembly = xnncommon.parse_target_name(target_name)

  mstep_match = re.search(r"mstep([0-9]+)", target_name)
  mr_packed = mr // int(mstep_match.group(1)) if mstep_match else mr

  requantization = common_fields[-3]
  if requantization not in ("fp32", "rndnu", "rndnu16"):
    requantization = None
  return mr, nr, kr, sr, mr_packed, vector_tile, requantization, arch, isa, assembly

# xngen template for one micro-kernel benchmark.
#
# Expands to a function that forwards the micro-kernel, and its optional
# params-init, weight-packing and packed-stride functions, to GEMMBenchmark,
# together with its mr/nr/kr/sr tiling. `mr_packed` is added only for 'qp8'.
# The function is then registered with BENCHMARK_GEMM, or with
# BENCHMARK_GEMM_BL for 'qb4w' kernels.
#
# When a cpp-check is given, the whole definition is wrapped in
# `#if ${CPP_CHECK}` / `#endif`. Substitution variables are supplied by
# generate_test_cases().
GEMM_BENCH_CODE = """\
$if CPP_CHECK:
  #if ${CPP_CHECK}
static void ${UKERNEL_NAME}(benchmark::State& state, const char* net) {
  GEMMBenchmark(state,
    ${GEMM},
    $if INIT_PARAMS is not None:
      ${INIT_PARAMS},
    $if PACK_FN is not None:
      ${PACK_FN},
    $if PACKED_STRIDE_FN is not None:
      ${PACKED_STRIDE_FN},
    /*mr=*/${MR}, /*nr=*/${NR}${NR_SCALE}, /*kr=*/${KR}, /*sr=*/${SR},
    $if DATATYPE in ('qp8',):
      /*mr_packed=*/${MR_PACKED},
    $if ISA_CHECK:
      benchmark::utils::${ISA_CHECK});
    $else:
      /*isa_check=*/nullptr);
}\n
$if KERNELTYPE in ['qb4w']:
  BENCHMARK_GEMM_BL(${UKERNEL_NAME})
$else:
  BENCHMARK_GEMM(${UKERNEL_NAME})
$if CPP_CHECK:
  #endif  // ${CPP_CHECK}
"""

# xngen template for the C++ `CreateTests` function.
#
# The function builds the list of GemmTestParams for one combination of tiling
# parameters: k/n/m sweeps, sub-tiles, strides, zero points and clamping.
# main() deduplicates the expanded text across micro-kernels and renames each
# distinct copy to `CreateTests<N>`.
#
# Strided-A tests are only generated for direct GEMM (`!is_igemm`).
# Indirect-GEMM kernels get the small_kernel / a_offset / zero tests instead.
GEMM_CREATE_TESTS_CODE = """\
std::vector<GemmTestParams> CreateTests(
    size_t k_block, size_t adj_k_block,
    size_t mr, size_t nr, size_t kr, size_t sr,
    $if DATATYPE in ('qp8',):
      size_t mr_packed,
    bool is_igemm,
    bool unsigned_inputs,
    std::function<void(GemmMicrokernelTester& tester)> test_func,
    std::function<void()> isa_check = nullptr) {
  std::string kbs = std::to_string(k_block);
  std::string kb2s = std::to_string(k_block * 2);
  std::string akbs = std::to_string(adj_k_block);
  $if NR_SCALE != "":
    nr = nr${NR_SCALE};
  std::string nrs = std::to_string(nr);

  $if DATATYPE in ('qp8',):
    const GemmMicrokernelTester tester = GemmMicrokernelTester()
        .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs);
  $else:
    const GemmMicrokernelTester tester = GemmMicrokernelTester()
        .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs);

  std::vector<GemmTestParams> gemm_tests;
  gemm_tests.reserve(42);

  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs,
      tester.clone()
          .m(mr).n(nr).k(k_block)
          $if KERNELTYPE in ['qb4w', 'qc4w']:
            .b_zero_point(8)
          $if KERNELTYPE in ['qb4w']:
            .bl(32)
      , test_func, isa_check));
  $if DATATYPE != "qp8":
    gemm_tests.push_back(GemmTestParams(
        "strided_cn",
        tester.clone()
            .m(mr).n(nr).k(k_block)
            .cn_stride(xnnpack::NextPrime(nr + 1))
            $if KERNELTYPE in ['qb4w', 'qc4w']:
              .b_zero_point(8)
            $if KERNELTYPE in ['qb4w']:
              .bl(32)
      , test_func, isa_check));
  if (!is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "k_eq_" + kbs + "_strided_a",
        tester.clone()
            .m(mr).n(nr).k(k_block)
            .a_stride(xnnpack::NextPrime(k_block + 1))
            $if KERNELTYPE in ['qb4w', 'qc4w']:
              .b_zero_point(8)
            $if KERNELTYPE in ['qb4w']:
              .bl(32)
        , test_func, isa_check));
  }
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile",
      tester.clone()
          .k(k_block).iterations(1)
          $if KERNELTYPE in ['qb4w', 'qc4w']:
            .b_zero_point(8)
          $if KERNELTYPE in ['qb4w']:
            .bl(32)
      , test_func, isa_check)
      .loop_n(1, nr)
      .loop_m(1, mr));
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile_m",
      tester.clone()
          .n(nr).k(k_block).iterations(1)
          $if KERNELTYPE in ['qb4w', 'qc4w']:
            .b_zero_point(8)
          $if KERNELTYPE in ['qb4w']:
            .bl(32)
      , test_func, isa_check)
      .loop_m(1, mr));
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile_n",
      tester.clone()
          .m(mr).k(k_block).iterations(1)
          $if KERNELTYPE in ['qb4w', 'qc4w']:
            .b_zero_point(8)
          $if KERNELTYPE in ['qb4w']:
            .bl(32)
      , test_func, isa_check)
      .loop_n(1, nr));
  $if IS_PIPELINED:
    gemm_tests.push_back(GemmTestParams(
        "k_eq_" + kb2s,
        tester.clone()
            .m(mr).n(nr).k(k_block * 2)
            $if KERNELTYPE in ['qb4w', 'qc4w']:
              .b_zero_point(8)
            $if KERNELTYPE in ['qb4w']:
              .bl(32)
        , test_func, isa_check));
    if (!is_igemm) {
      gemm_tests.push_back(GemmTestParams(
          "k_eq_" + kb2s + "_strided_a",
          tester.clone()
              .m(mr).n(nr).k(k_block * 2)
              .a_stride(xnnpack::NextPrime(k_block * 2 + 1))
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check));
    }
    gemm_tests.push_back(GemmTestParams(
        "k_eq_" + kb2s + "_subtile",
        tester.clone()
            .k(k_block * 2).iterations(1)
            $if KERNELTYPE in ['qb4w', 'qc4w']:
              .b_zero_point(8)
            $if KERNELTYPE in ['qb4w']:
              .bl(32)
        , test_func, isa_check)
        .loop_n(1, nr)
        .loop_m(1, mr));
  $if KERNELTYPE not in ['qb4w']:
      if (k_block > 1) {
        gemm_tests.push_back(GemmTestParams(
            "k_lt_" + akbs,
            tester.clone()
                .m(mr).n(nr)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, adj_k_block - 1));
        if (!is_igemm) {
          gemm_tests.push_back(GemmTestParams(
              "k_lt_" + akbs + "_strided_a",
              tester.clone()
                  .m(mr).n(nr)
                  .a_stride(xnnpack::NextPrime(adj_k_block + 1))
                  $if KERNELTYPE in ['qb4w', 'qc4w']:
                    .b_zero_point(8)
                  $if KERNELTYPE in ['qb4w']:
                    .bl(32)
              , test_func, isa_check)
              .loop_k(1, adj_k_block - 1));
        }
        gemm_tests.push_back(GemmTestParams(
            "k_lt_" + akbs + "_subtile",
            tester.clone()
                .iterations(1)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, adj_k_block - 1)
            .loop_n(1, nr)
            .loop_m(1, mr));
      }
      gemm_tests.push_back(GemmTestParams(
          "k_gt_" + akbs,
          tester.clone()
              .m(mr).n(nr)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
      if (!is_igemm) {
        gemm_tests.push_back(GemmTestParams(
            "k_gt_" + akbs + "_strided_a",
            tester.clone()
                .m(mr).n(nr)
                .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
      }
      gemm_tests.push_back(GemmTestParams(
          "k_gt_" + akbs + "_subtile",
          tester.clone()
              .iterations(1)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)
          .loop_n(1, nr)
          .loop_m(1, mr));
      if (k_block > 1) {
        gemm_tests.push_back(GemmTestParams(
            "k_div_" + kbs,
            tester.clone()
                .m(mr).n(nr)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(adj_k_block + k_block, k_block * 5, k_block));
        if (!is_igemm) {
          gemm_tests.push_back(GemmTestParams(
              "k_div_" + kbs + "_strided_a",
              tester.clone()
                  .m(mr).n(nr)
                  .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
                  $if KERNELTYPE in ['qb4w', 'qc4w']:
                    .b_zero_point(8)
                  $if KERNELTYPE in ['qb4w']:
                    .bl(32)
              , test_func, isa_check)
              .loop_k(adj_k_block + k_block, k_block * 3, k_block));
        }
        gemm_tests.push_back(GemmTestParams(
            "k_div_" + kbs + "_subtile",
            tester.clone()
                .iterations(1)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(adj_k_block + k_block, k_block * 5, k_block)
            .loop_n(1, nr)
            .loop_m(1, mr));
      }
      gemm_tests.push_back(GemmTestParams(
          "n_gt_" + nrs,
          tester.clone()
              .m(mr)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          $if NR_SCALE != "":
            .loop_n(nr + 1, nr * 2 - 1, 4)
          $else:
            .loop_n(nr + 1, nr * 2 - 1)
          .loop_k(1, k_block * 3, k_block + 1));
      $if DATATYPE != "qp8":
        gemm_tests.push_back(GemmTestParams(
            "n_gt_" + nrs + "_strided_cn",
            tester.clone()
                .m(mr)
                .cn_stride(xnnpack::NextPrime(nr + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            $if NR_SCALE != "":
              .loop_n(nr + 1, nr * 2 - 1, 4)
            $else:
              .loop_n(nr + 1, nr * 2 - 1)
            .loop_k(1, k_block * 3, k_block + 1));
      if (!is_igemm) {
        gemm_tests.push_back(GemmTestParams(
            "n_gt_" + nrs + "_strided_a",
            tester.clone()
                .m(mr)
                .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            $if NR_SCALE != "":
              .loop_n(nr + 1, nr * 2 - 1, 4)
            $else:
              .loop_n(nr + 1, nr * 2 - 1)
            .loop_k(1, k_block * 3, k_block));
      }
      gemm_tests.push_back(GemmTestParams(
          "n_gt_" + nrs + "_subtile",
          tester.clone()
              .iterations(1)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          $if NR_SCALE != "":
            .loop_n(nr + 1, nr * 2 - 1, 4)
          $else:
            .loop_n(nr + 1, nr * 2 - 1)
          .loop_k(1, k_block * 3, k_block + 1)
          .loop_m(1, mr));
      gemm_tests.push_back(GemmTestParams(
          "n_div_" + nrs,
          tester.clone()
              .m(mr)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          .loop_n(nr * 2, nr * 3, nr)
          .loop_k(1, k_block * 3, k_block + 1));
      $if DATATYPE != "qp8":
        gemm_tests.push_back(GemmTestParams(
            "n_div_" + nrs + "_strided_cn",
            tester.clone()
                .m(mr)
                .cn_stride(xnnpack::NextPrime(nr + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_n(nr * 2, nr * 3, nr)
            .loop_k(1, k_block * 3, k_block + 1));
      if (!is_igemm) {
        gemm_tests.push_back(GemmTestParams(
            "n_div_" + nrs + "_strided_a",
            tester.clone()
                .m(mr)
                .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_n(nr * 2, nr * 3, nr)
            .loop_k(1, k_block * 3, k_block));
      }
      gemm_tests.push_back(GemmTestParams(
          "n_div_" + nrs + "_subtile",
          tester.clone()
              .iterations(1)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          .loop_n(nr * 2, nr * 3, nr)
          .loop_k(1, k_block * 3, k_block + 1)
          .loop_m(1, mr));
      if (is_igemm) {
        gemm_tests.push_back(GemmTestParams(
            "small_kernel",
            tester.clone()
                .m(mr).n(nr).ks(3)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1));
        gemm_tests.push_back(GemmTestParams(
            "small_kernel_subtile",
            tester.clone()
                .ks(3).iterations(1)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1)
            .loop_n(1, nr)
            .loop_m(1, mr));
        gemm_tests.push_back(GemmTestParams(
            "n_gt_" + nrs + "_small_kernel",
            tester.clone()
                .m(mr).ks(3)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            $if NR_SCALE != "":
              .loop_n(nr + 1, nr * 2 - 1, 4)
            $else:
              .loop_n(nr + 1, nr * 2 - 1)
            .loop_k(1, k_block * 3, k_block + 1));
        gemm_tests.push_back(GemmTestParams(
            "n_div_" + nrs + "_small_kernel",
            tester.clone()
                .m(mr).ks(3)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_n(nr * 2, nr * 3, nr)
            .loop_k(1, k_block * 3, k_block + 1));
      }
      gemm_tests.push_back(GemmTestParams(
          "strided_cm_subtile",
          tester.clone()
              .mr(mr).nr(nr).kr(kr).sr(sr)
              .cm_stride(xnnpack::NextPrime(nr + 1))
              .iterations(1)
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check)
          .loop_k(1, k_block * 3, k_block + 1)
          .loop_n(1, nr)
          .loop_m(1, mr));
      if (is_igemm) {
        gemm_tests.push_back(GemmTestParams(
            "a_offset",
            tester.clone()
                .m(mr).n(nr).ks(3)
                .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1));
        gemm_tests.push_back(GemmTestParams(
            "zero",
            tester.clone()
                .m(mr).n(nr).ks(3)
                .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1)
            .loop_zi(0, mr - 1));
      }
      $if ACTIVATION == "MINMAX":
        gemm_tests.push_back(GemmTestParams(
            "qmin",
            tester.clone()
                .m(mr).n(nr).k(k_block).qmin(128)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check));
        gemm_tests.push_back(GemmTestParams(
            "qmax",
            tester.clone()
                .m(mr).n(nr).k(k_block).qmax(128)
                $if KERNELTYPE in ['qb4w', 'qc4w']:
                  .b_zero_point(8)
                $if KERNELTYPE in ['qb4w']:
                  .bl(32)
            , test_func, isa_check));
      gemm_tests.push_back(GemmTestParams(
          "strided_cm",
          tester.clone()
              .m(mr).n(nr).k(k_block)
              .cm_stride(xnnpack::NextPrime(nr + 1))
              $if KERNELTYPE in ['qb4w', 'qc4w']:
                .b_zero_point(8)
              $if KERNELTYPE in ['qb4w']:
                .bl(32)
          , test_func, isa_check));
      $if DATATYPE == "qu8":
        gemm_tests.push_back(GemmTestParams(
            "no_a_zero_point",
            tester.clone()
                .m(mr).n(nr).a_zero_point(0)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1));
        gemm_tests.push_back(GemmTestParams(
            "no_b_zero_point",
            tester.clone()
                .m(mr).n(nr).b_zero_point(0)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1));
        gemm_tests.push_back(GemmTestParams(
            "b_zero_point",
            tester.clone()
                .m(mr).n(nr).k(k_block)
            , test_func, isa_check)
            .loop_bzp(0, 255));
        gemm_tests.push_back(GemmTestParams(
            "no_zero_point",
            tester.clone()
                .m(mr).n(nr)
                .a_zero_point(0)
                .b_zero_point(0)
            , test_func, isa_check)
            .loop_k(1, k_block * 3, k_block + 1));
  $if KERNELTYPE in ['qb4w']:
    gemm_tests.push_back(GemmTestParams(
        "bl",
        tester.clone()
            .m(mr).n(nr).k(k_block * 12)
            .b_zero_point(8)
        , test_func, isa_check)
        .loop_k(k_block, k_block * 12, k_block, LoopStepType::Linear)
        .loop_bl(32, k_block * 32, 32));

  return gemm_tests;
}
"""

# xngen template for the tests of one micro-kernel.
#
# Expands to an INSTANTIATE_TEST_SUITE_P over the GemmTest fixture. Its
# parameters come from `CreateTests(...)`; main() later rewrites this call to
# the deduplicated `CreateTests<N>(...)`. The per-case lambda calls
# `tester.Test(...)` with TEST_ARGS, and the ISA check, if any, is passed as a
# separate lambda.
#
# For f32/f16 kernels whose test name starts with GENERATE it also emits:
#   - a `subtile_m_upto_mr` test;
#   - when a prototype is given, a `matches_assembly` test guarded by
#     XNN_ENABLE_ASSEMBLY.
GEMM_TEST_CODE = """\
$if CPP_CHECK:
  #if ${CPP_CHECK}
INSTANTIATE_TEST_SUITE_P(
    ${TEST_NAME}, GemmTest,
    testing::ValuesIn(CreateTests(
        /*k_block=*/${KBLOCK},
        /*adj_k_block=*/${ADJKBLOCK},
        /*mr=*/${MR}, /*nr=*/${NR}, /*kr=*/${KR}, /*sr=*/${SR},
        $if DATATYPE in ('qp8',):
          /*mr_packed=*/${MR_PACKED},
        /*is_igemm=*/${"true" if UKERNEL_TYPE.startswith("IGEMM") else "false"},
        /*unsigned_inputs=*/${"true" if UNSIGNED_INPUTS else "false"},
        [](GemmMicrokernelTester& tester) {
          tester.Test(${",\\n                      ".join(TEST_ARGS)});
        $if ISA_CHECK:
          },
          []() {
            ${ISA_CHECK};
          })),
        $else:
          })),
    [](const testing::TestParamInfo<GemmTest::ParamType>& info) {
      return info.param.test_name;
    });

$if TEST_NAME.startswith('GENERATE') and DATATYPE in ['f32', 'f16']:
  TEST(${TEST_NAME}, subtile_m_upto_mr) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t max_mr = 1; max_mr <= ${MR}; max_mr++) {
      for (uint32_t m = 1; m <= max_mr; m++) {
        for (size_t k = 1; k <= ${KBLOCK * 2}; k += 1) {
          GemmMicrokernelTester()
            .mr(max_mr)
            $if NR > 1:
              .nr(${NR})
            $if KR > 1:
              .kr(${KR})
            $if SR > 1:
              .sr(${SR})
            .m(m)
            $if NR > 1:
              .n(${NR})
            .k(k)
            .iterations(1)
            $if KERNELTYPE in ['qb4w', 'qc4w']:
              .b_zero_point(8)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

$if TEST_NAME.startswith('GENERATE') and DATATYPE in ['f32', 'f16'] and PROTOTYPE is not None:
  #if XNN_ENABLE_ASSEMBLY
    TEST(${TEST_NAME}, matches_assembly) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        $if MR > 1:
          .mr(${MR})
        $if NR > 1:
          .nr(${NR})
        $if KR > 1:
          .kr(${KR})
        $if SR > 1:
          .sr(${SR})
        $if MR > 1:
          .m(${MR})
        $if NR > 1:
          .n(${NR})
        .k(${KBLOCK})
        .Test(
            ${", ".join(TEST_ARGS)},
            &${PROTOTYPE});
    }
  #endif // XNN_ENABLE_ASSEMBLY
$if CPP_CHECK:
  #endif  // ${CPP_CHECK}
"""


def generate_test_cases(
    ukernel,
    mr,
    nr,
    kr,
    sr,
    mr_packed,
    k_block,
    unsigned_inputs,
    vector_tile,
    init_fn,
    pack_fn,
    packed_stride_fn,
    requantization,
    is_pipelined,
    cpp_check,
    isa,
    prototype,
):
  """Generates all tests cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    mr_packed: Optional MR parameter for the left-hand packing function.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    unsigned_inputs: whether the inputs should be converted to unsigned
      integers. Some microkernels are more efficient with unsigned inputs.
    vector_tile: Indicates if vector tile for NR is specified in vectors rather
      than elements.
    init_fn: C name of the function to initialize microkernel parameters.
    pack_fn: C name of the function to pack the weights.
    packed_stride_fn: C name of the function to compute the packed weights
      stride.
    requantization: name of the requantization scheme used by the microkernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined loop
      and iteration of the pipelined loop.
    cpp_check: Optional preprocessor macro to check for the availability of the
      micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    prototype: Optional C name of another micro-kernel. When set, f32/f16
      kernels whose test name starts with GENERATE get a `matches_assembly`
      test that passes its address to `GemmMicrokernelTester::Test`.

  Returns:
    A tuple `(create_tests, test_case, benchmark)` of generated C++ code: the
    `CreateTests` function, the test-suite instantiation (plus any extra
    tests), and the benchmark definition.

  Raises:
    KeyError: If `vector_tile` is set for a datatype or ISA that has no known
      vector-length scale.
  """
  # Drop the first "_"-separated component (the library prefix) to form the
  # test and benchmark names.
  _, ukernel_name = ukernel.split("_", 1)

  # Default layout: <prefix>_<datatype>_<ukernel_type>_<activation>_<rest>.
  _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
  kerneltype = datatype
  # Float kernels with quantized weights: <prefix>_<f16|f32>_<qc8w|qc4w>_...
  if datatype in ["f16", "f32"] and ukernel_type in ["qc8w", "qc4w"]:
    _, datatype, kerneltype, ukernel_type, activation, _ = ukernel.split(
        "_", 5
    )
    datatype = datatype + "_" + kerneltype
  # Dynamically-quantized kernels: <prefix>_<qd8|qp8>_<f16|f32>_<weights>_...
  if (
      datatype in ("qd8", "qp8")
      and ukernel_type in ["f16", "f32"]
      and activation in ["qc8w", "qc4w", "qb4w"]
  ):
    _, datatype, _, kerneltype, ukernel_type, activation, _ = ukernel.split(
        "_", 6
    )

  if activation == "ukernel":
    activation = "linear"
  if activation in ["qs8w"]:
    _, _, _, _, _, activation, _ = ukernel.split("_", 6)

  # C arguments forwarded to GemmMicrokernelTester::Test, in order.
  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
  if pack_fn:
    test_args.append(pack_fn)
  if packed_stride_fn:
    test_args.append(packed_stride_fn)
  if init_fn and requantization:
    requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
    test_args.append(
        "xnn_%s_requantize_%s" % (requantization_datatype, requantization)
    )

  # For vector tiles, NR is multiplied at runtime by the number of elements per
  # vector register (only RVV is supported here).
  nr_scale = ""
  if vector_tile:
    ctype = {
        "qs8": "int8_t",
        "qd8": "int32_t",
        "qp8": "int8_t",
        "qu8": "uint8_t",
        "f16": "uint16_t",
        "f32": "float",
    }[datatype]
    nr_scale = {"rvv": " * xnn_init_hardware_config()->vlenb / sizeof(%s)" % ctype}[isa]

  template_vars = {
      "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "KERNELTYPE": kerneltype,
      "ACTIVATION": activation.upper(),
      "MR": mr,
      "NR": nr,
      "KR": kr,
      "SR": sr,
      "MR_PACKED": mr_packed,
      "KBLOCK": k_block,
      "UNSIGNED_INPUTS": unsigned_inputs,
      "NR_SCALE": nr_scale,
      # Pipelined kernels consume two K blocks per main-loop trip.
      "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
      "PROTOTYPE": prototype,
      "CPP_CHECK": cpp_check,
  }

  create_test_case = xngen.preprocess(GEMM_CREATE_TESTS_CODE, template_vars)
  test_case = xngen.preprocess(GEMM_TEST_CODE, template_vars)

  benchmark = xngen.preprocess(
      GEMM_BENCH_CODE,
      {
          "UKERNEL_NAME": ukernel_name,
          "GEMM": ukernel,
          "KERNELTYPE": kerneltype,
          "DATATYPE": datatype,
          "INIT_PARAMS": init_fn,
          "PACK_FN": pack_fn,
          "PACKED_STRIDE_FN": packed_stride_fn,
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "MR_PACKED": mr_packed,
          "NR_SCALE": nr_scale,
          "ISA_CHECK": xnncommon.generate_isa_utilcheck_macro(isa),
          "CPP_CHECK": cpp_check,
      },
  )
  return create_test_case, test_case, benchmark


def _strip_redundant_cpp_guards(text):
  """Collapses `#endif  // X` lines that are directly followed by `#if X`.

  Consecutive micro-kernels guarded by the same preprocessor condition then
  share a single `#if`/`#endif` region. Two passes are made, so a pair that
  only becomes adjacent after the first substitution (e.g. nested guards) is
  collapsed as well.

  Args:
    text: Generated C++ source.

  Returns:
    The source with each redundant guard pair replaced by a blank line.
  """
  for _ in range(2):
    text = re.sub(
        r"^ *\#endif  // ([^\n]+)\n+ *\#if \1\n",
        "\n",
        text,
        flags=re.MULTILINE,
    )
  return text


def main(args):
  """Generates GEMM unit tests, and optionally benchmarks, from a YAML spec.

  Micro-kernels are sharded across the `--output-test` files by a CRC32 of
  their name. Benchmarks all go to the single `--output-bench` file.

  Args:
    args: Command-line arguments, excluding the program name.

  Raises:
    ValueError: If the spec file does not contain a list of micro-kernels.
  """
  options = parser.parse_args(args)
  num_output_files = len(options.output_test)

  with open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests_header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "xnnpack/allocator.h"
#include "xnnpack/common.h"
#include "xnnpack/gemm.h"
#include "xnnpack/igemm.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/pack.h"
#include "xnnpack/packw.h"
#include "xnnpack/ppmm.h"
#include "xnnpack/requantization.h"
#include "gemm-microkernel-tester.h"
#include "next_prime.h"
""".format(specification=options.spec, generator=sys.argv[0])

  bench_header = """\
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}

#include <benchmark/benchmark.h>
#include "gemm-benchmark.h"
#include "utils.h"
#include "xnnpack/common.h"
#include "xnnpack/gemm.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/pack.h"
#include "xnnpack/packw.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Every test output gets an entry up front, so a shard that no micro-kernel
  # hashes to is still written (header only) instead of raising KeyError.
  test_outputs = {output_name: "" for output_name in options.output_test}

  isa_hierarchy = xnncommon.isa_hierarchy_map()
  # Benchmark code bucketed by the ISA's index in the hierarchy; buckets are
  # emitted in reverse index order below.
  bench_chunks = [""] * len(isa_hierarchy)

  # Deduplicated `CreateTests` functions: dedup key -> index -> C++ code.
  create_tests_idx_by_key = {}
  create_tests_from_idx = {}

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    unsigned_inputs = bool(int(ukernel_spec.get("unsigned-inputs", 0)))
    init_fn = ukernel_spec.get("init")
    pack_fn = ukernel_spec.get("pack")
    packed_stride_fn = ukernel_spec.get("packed-stride")
    pipelined = bool(ukernel_spec.get("pipelined", False))
    cpp_check = ukernel_spec.get("cpp-check", False)
    prototype = ukernel_spec.get("prototype")
    (
        mr,
        nr,
        kr,
        sr,
        mr_packed,
        vector_tile,
        requantization,
        arch,
        isa,
        assembly,
    ) = split_ukernel_name(name)

    create_tests, test_case, bench_case = generate_test_cases(
        name,
        mr,
        nr,
        kr,
        sr,
        mr_packed,
        k_block,
        unsigned_inputs,
        vector_tile,
        init_fn,
        pack_fn,
        packed_stride_fn,
        requantization,
        pipelined,
        cpp_check,
        isa,
        prototype,
    )

    # Store or reuse the `CreateTests` function. Dedup is keyed on the full
    # text (not its hash) so distinct bodies can never be merged. RVV copies
    # are wrapped in architecture guards, so they are keyed separately and
    # can never be reused by an unguarded kernel on another architecture.
    is_rvv = isa == "rvv"
    guard_key = repr((arch, isa, assembly)) if is_rvv else None
    dedup_key = (create_tests, guard_key)
    create_tests_idx = create_tests_idx_by_key.get(dedup_key)
    if create_tests_idx is None:
      create_tests_idx = len(create_tests_idx_by_key) + 1
      create_tests_idx_by_key[dedup_key] = create_tests_idx
      create_tests_code = create_tests.replace(
          "CreateTests(", f"CreateTests{create_tests_idx}("
      )
      if is_rvv:
        create_tests_code = xnncommon.postprocess_test_case(
            create_tests_code, arch, isa, assembly
        )
      create_tests_from_idx[create_tests_idx] = create_tests_code
    test_case = test_case.replace(
        "CreateTests(", f"CreateTests{create_tests_idx}("
    )

    # Hash the name of each microkernel and figure out which output file to
    # write it to.
    output_index = zlib.crc32(bytes(name, "utf-8")) % num_output_files
    test_outputs[options.output_test[output_index]] += (
        "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa, assembly)
    )
    bench_chunks[isa_hierarchy.get(isa, 0)] += (
        "\n\n" + xnncommon.postprocess_test_case(bench_case, arch, isa, assembly)
    )

  if options.output_bench:
    bench_outputs = (
        bench_header
        + "".join(reversed(bench_chunks))
        + """\n
#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
"""
    )
    xnncommon.overwrite_if_changed(
        options.output_bench, _strip_redundant_cpp_guards(bench_outputs)
    )

  all_create_tests = (
      "namespace {\n\n"
      + "\n".join(create_tests_from_idx.values())
      + "\n}  // namespace\n"
  )
  for output_name in options.output_test:
    test_output = (
        tests_header + "\n" + all_create_tests + test_outputs[output_name]
    )
    xnncommon.overwrite_if_changed(
        output_name, _strip_redundant_cpp_guards(test_output)
    )


if __name__ == "__main__":
  # Forward everything after the program name to the generator.
  cli_arguments = sys.argv[1:]
  main(cli_arguments)
