#
# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
#
# SPDX-License-Identifier: Apache-2.0
#

load(
    "//:kai_defs.bzl",
    "kai_c_library",
    "kai_cpu_bf16",
    "kai_cpu_dotprod",
    "kai_cpu_fp16",
    "kai_cpu_i8mm",
    "kai_cpu_neon",
    "kai_cpu_sme",
    "kai_cpu_sme2",
)

package(default_visibility = ["//visibility:private"])

# Kernel lists.
#
# Each entry is a kernel path stem, relative to this package and without a file
# extension. The kai_c_library targets in this file turn the stems of the C lists
# into "<stem>.c" sources and "<stem>.h" textual headers, and the stems of the
# *_ASM lists into "<stem>_asm.S" sources.
#
# Every list carries a "buildifier: keep sorted" marker so buildifier enforces a
# stable order and additions produce minimal diffs.

# buildifier: keep sorted
SCALAR_KERNELS = [
    "pack/kai_lhs_quant_pack_qai8dxp_f32",
    "pack/kai_lhs_quant_pack_qsi8d32p_f32",
    "pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0",
    "pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0",
]

# buildifier: keep sorted
NEON_KERNELS = [
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla",
    "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon",
    "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon",
    "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0",
    "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0",
    "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0",
    "pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon",
    "pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon",
]

# Assembly counterparts of NEON_KERNELS; every stem here is also listed there.
# NOTE(review): presumably the .c file of the same stem exposes the entry point
# and header for the assembly routine — confirm before adding an ASM-only stem.
# buildifier: keep sorted
NEON_KERNELS_ASM = [
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla",
]

# buildifier: keep sorted
FP16_KERNELS = [
    "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla",
    "pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon",
]

# buildifier: keep sorted
BF16_KERNELS = [
    "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot",
    "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla",
    "pack/kai_lhs_quant_pack_bf16p1x4_f32_neon",
    "pack/kai_lhs_quant_pack_bf16p8x4_f32_neon",
    "pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon",
]

# buildifier: keep sorted
FP16_BF16_KERNELS = [
    "matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla",
    "pack/kai_lhs_pack_bf16p8x4_f16_neon",
    "pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon",
    "pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon",
]

# buildifier: keep sorted
DOTPROD_KERNELS = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod",
]

# Assembly counterparts of a subset of DOTPROD_KERNELS; every stem here is also
# listed there.
# buildifier: keep sorted
DOTPROD_KERNELS_ASM = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod",
]

# buildifier: keep sorted
I8MM_KERNELS = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
]

# Assembly counterparts of a subset of I8MM_KERNELS; every stem here is also
# listed there. Ordered the same way as I8MM_KERNELS (lexicographic), so the
# "16x4x32" variant precedes the "8x4x32" one.
# buildifier: keep sorted
I8MM_KERNELS_ASM = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm",
]

# buildifier: keep sorted
SME_KERNELS = [
    "pack/kai_lhs_pack_bf16p2vlx2_f32_sme",
    "pack/kai_lhs_pack_f32p2vlx1_f32_sme",
    "pack/kai_lhs_pack_x16p2vlx2_x16_sme",
    "pack/kai_lhs_pack_x8p2vlx4_x8_sme",
    "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme",
    "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme",
    "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme",
    "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme",
    "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme",
    "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme",
    "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme",
]
# buildifier: keep sorted
SME2_KERNELS = [
    "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot",
    "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa",
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla",
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla",
    "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot",
    "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa",
    "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa",
]

# Public header-only target: every "*_interface.h" found anywhere under this
# package, exposed as textual headers.
kai_c_library(
    name = "interface",
    textual_hdrs = glob(["**/*_interface.h"]),
    visibility = ["//visibility:public"],
)

# Per-ISA implementation libraries. Each one maps a kernel stem list to
# "<stem>.c" sources plus "<stem>.h" textual headers and, except for the scalar
# set, passes the matching kai_cpu_*() value as cpu_uarch.

# No cpu_uarch is set for these kernels.
kai_c_library(
    name = "scalar_impl",
    srcs = ["{}.c".format(kernel) for kernel in SCALAR_KERNELS],
    textual_hdrs = ["{}.h".format(kernel) for kernel in SCALAR_KERNELS],
)

kai_c_library(
    name = "neon_impl",
    srcs = ["{}.c".format(kernel) for kernel in NEON_KERNELS],
    cpu_uarch = kai_cpu_neon(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in NEON_KERNELS],
)

# Assembly-only companion: "<stem>_asm.S" sources, no headers of its own.
kai_c_library(
    name = "neon_impl_asm",
    srcs = ["{}_asm.S".format(kernel) for kernel in NEON_KERNELS_ASM],
    cpu_uarch = kai_cpu_neon(),
)

kai_c_library(
    name = "fp16_impl",
    srcs = ["{}.c".format(kernel) for kernel in FP16_KERNELS],
    cpu_uarch = kai_cpu_fp16(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in FP16_KERNELS],
)

kai_c_library(
    name = "bf16_impl",
    srcs = ["{}.c".format(kernel) for kernel in BF16_KERNELS],
    cpu_uarch = kai_cpu_bf16(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in BF16_KERNELS],
)

# Mixed-precision kernels: cpu_uarch is the fp16 and bf16 values combined with "+".
kai_c_library(
    name = "fp16_bf16_impl",
    srcs = ["{}.c".format(kernel) for kernel in FP16_BF16_KERNELS],
    cpu_uarch = kai_cpu_fp16() + kai_cpu_bf16(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in FP16_BF16_KERNELS],
)

kai_c_library(
    name = "dotprod_impl",
    srcs = ["{}.c".format(kernel) for kernel in DOTPROD_KERNELS],
    cpu_uarch = kai_cpu_dotprod(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in DOTPROD_KERNELS],
)

# Assembly-only companion: "<stem>_asm.S" sources, no headers of its own.
kai_c_library(
    name = "dotprod_impl_asm",
    srcs = ["{}_asm.S".format(kernel) for kernel in DOTPROD_KERNELS_ASM],
    cpu_uarch = kai_cpu_dotprod(),
)

kai_c_library(
    name = "i8mm_impl",
    srcs = ["{}.c".format(kernel) for kernel in I8MM_KERNELS],
    cpu_uarch = kai_cpu_i8mm(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in I8MM_KERNELS],
)

# Assembly-only companion: "<stem>_asm.S" sources, no headers of its own.
kai_c_library(
    name = "i8mm_impl_asm",
    srcs = ["{}_asm.S".format(kernel) for kernel in I8MM_KERNELS_ASM],
    cpu_uarch = kai_cpu_i8mm(),
)

kai_c_library(
    name = "sme_impl",
    srcs = ["{}.c".format(kernel) for kernel in SME_KERNELS],
    cpu_uarch = kai_cpu_sme(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in SME_KERNELS],
)

kai_c_library(
    name = "sme2_impl",
    srcs = ["{}.c".format(kernel) for kernel in SME2_KERNELS],
    cpu_uarch = kai_cpu_sme2(),
    textual_hdrs = ["{}.h".format(kernel) for kernel in SME2_KERNELS],
)

# Public aggregate: the interface headers plus every implementation library
# defined above (C and assembly).
kai_c_library(
    name = "matmul",
    visibility = ["//visibility:public"],
    deps = [
        ":bf16_impl",
        ":dotprod_impl",
        ":dotprod_impl_asm",
        ":fp16_bf16_impl",
        ":fp16_impl",
        ":i8mm_impl",
        ":i8mm_impl_asm",
        ":interface",
        ":neon_impl",
        ":neon_impl_asm",
        ":scalar_impl",
        ":sme2_impl",
        ":sme_impl",
    ],
)