//// --------------------------
//// ATTENTION:
//// THIS CODE IS AUTOGENERATED
//// BY sve_emblookup_codegen.py
//// DO NOT MODIFY!!!
//// --------------------------

#include <arm_sve.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cstdint>
#include <cstring>
namespace caffe2 {

// Segmented (pooled) embedding lookup for float rows with int32 indices and
// int32 offsets, vectorized with ARM SVE.
//
// For each output row i in [0, output_size) this sums the input rows
// &input[indices[p] * block_size] (block_size floats each) for every index
// position p in [offsets[i] - offsets[0], offsets[i + 1] - offsets[0]),
// optionally scaling each row by a weight, and stores the result at
// &out[i * block_size].
//
// Parameters:
//   block_size           number of floats per input row and per output row.
//   output_size          number of segments / output rows; `offsets` must
//                        hold output_size + 1 entries.
//   index_size           number of entries in `indices`.
//   data_size            number of rows in `input`; every index must lie in
//                        [0, data_size).
//   input                row-major table of data_size * block_size floats.
//   indices              row ids, consumed sequentially across segments.
//   offsets              segment boundaries; segment i spans
//                        [offsets[i], offsets[i + 1]).
//   weights              optional per-index weights (nullptr => weight 1).
//                        With IS_WEIGHT_POSITIONAL the weight is selected by
//                        the position inside the segment, otherwise by the
//                        global index position.
//   scale_bias           unused for float input.
//   normalize_by_lengths if true, each non-empty segment sum is multiplied by
//                        1 / segment length.
//   out                  output_size * block_size floats.
//
// Returns false when the offsets are inconsistent with the consumed indices,
// when a segment would extend past index_size, or when an index is out of
// range (output contents are then unspecified). Otherwise returns whether
// exactly index_size indices were consumed.
//
// block_size == {32, 16, 8, 4, 2} * (SVE vector length in floats) use fully
// unrolled register accumulators; any other block_size uses the predicated
// generic loop that accumulates directly in `out`.
//
// NOTE(review): this file is generated by sve_emblookup_codegen.py; the
// index_size bounds check added in each path must also be added to the
// generator, or it will be lost on regeneration.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_float_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        vsum8 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8);
        vsum9 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9);
        vsum10 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10);
        vsum11 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11);
        vsum12 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12);
        vsum13 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13);
        vsum14 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14);
        vsum15 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15);
        vsum16 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16);
        vsum17 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17);
        vsum18 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18);
        vsum19 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19);
        vsum20 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20);
        vsum21 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21);
        vsum22 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22);
        vsum23 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23);
        vsum24 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24);
        vsum25 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25);
        vsum26 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26);
        vsum27 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27);
        vsum28 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28);
        vsum29 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29);
        vsum30 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30);
        vsum31 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        vsum8 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8);
        vsum9 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9);
        vsum10 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10);
        vsum11 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11);
        vsum12 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12);
        vsum13 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13);
        vsum14 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14);
        vsum15 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code: accumulate directly into the output row using predicated
    // loads/stores so any block_size is handled.
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject segments that would read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* ip = &input[idx * block_size];
        svbool_t pg;
        // Loop while the first lane of the tail predicate is active, i.e.
        // while k < block_size.
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// Entry point for non-positional weights: a non-null `weights` array is
// indexed by the global index position rather than the position within the
// segment. Every argument is forwarded unchanged to the
// IS_WEIGHT_POSITIONAL == false instantiation, whose result is returned.
bool EmbeddingLookupIdx_int32_t_float_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = false;
  return EmbeddingLookupIdx_int32_t_float_float__sve<kPositionalWeights>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Entry point for positional weights: a non-null `weights` array is indexed
// by the position inside each segment (0 for the first index of a segment).
// Every argument is forwarded unchanged to the IS_WEIGHT_POSITIONAL == true
// instantiation, whose result is returned.
bool EmbeddingLookupIdx_int32_t_float_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = true;
  return EmbeddingLookupIdx_int32_t_float_float__sve<kPositionalWeights>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

// EmbeddingLookupIdx kernel (SVE): int64 indices/offsets, float input,
// float output.
//
// For every output row i in [0, output_size), sums the block_size-wide rows
// input[idx * block_size ...] for idx = indices[pos], where pos walks the
// segment [offsets[i] - offsets[0], offsets[i + 1] - offsets[0]).
//
// Each gathered row is scaled by a weight:
//   - 1.0f when `weights` is null;
//   - otherwise weights[j - offsets[i]] when IS_WEIGHT_POSITIONAL is true;
//   - otherwise weights[pos].
//
// If `normalize_by_lengths` is set and the segment is non-empty, the sum is
// multiplied by 1.0f / segment_length. The result is stored to
// out[i * block_size ...].
//
// `offsets` is read at indices 0..output_size, so it must hold at least
// output_size + 1 entries. `scale_bias` is not used by this float-input
// variant.
//
// Returns false if any of the following holds:
//   - a segment's start does not equal the number of indices consumed so far;
//   - a segment would extend past the first index_size entries of `indices`;
//   - an index lies outside [0, data_size);
//   - the segments do not consume exactly index_size indices.
// On a false return, `out` may be partially written.
//
// When block_size is 32/16/8/4/2 times the SVE vector length (in floats), row
// sums are kept in unrolled vector accumulators. Any other block_size uses a
// predicated loop that accumulates directly in `out`.
//
// NOTE(review): this file is generated by sve_emblookup_codegen.py; the
// segment-bound check below should also be added to the generator so it
// survives regeneration and reaches the other type variants.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_float_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Reject a segment that would read past indices[index_size - 1].
      // pos only grows, so such input could never pass the final
      // pos == index_size check. Failing here avoids the out-of-bounds reads
      // of indices[] (and weights[]) that would otherwise happen first.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        vsum8 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8);
        vsum9 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9);
        vsum10 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10);
        vsum11 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11);
        vsum12 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12);
        vsum13 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13);
        vsum14 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14);
        vsum15 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15);
        vsum16 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16);
        vsum17 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17);
        vsum18 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18);
        vsum19 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19);
        vsum20 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20);
        vsum21 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21);
        vsum22 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22);
        vsum23 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23);
        vsum24 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24);
        vsum25 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25);
        vsum26 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26);
        vsum27 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27);
        vsum28 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28);
        vsum29 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29);
        vsum30 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30);
        vsum31 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Segment must not read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        vsum8 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8);
        vsum9 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9);
        vsum10 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10);
        vsum11 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11);
        vsum12 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12);
        vsum13 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13);
        vsum14 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14);
        vsum15 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Segment must not read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        vsum4 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4);
        vsum5 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5);
        vsum6 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6);
        vsum7 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Segment must not read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        vsum2 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2);
        vsum3 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Segment must not read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0);
        vsum1 =
            svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code: accumulate straight into `out` with a predicated loop,
    // which handles any block_size (including a partial final vector).
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      // Segment must not read past indices[index_size - 1].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const float* ip = &input[idx * block_size];
        svbool_t pg;
        // svwhilelt yields an all-false predicate once k >= block_size, so
        // svptest_first ends the loop after the last (possibly partial)
        // vector.
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  // Every index must be consumed exactly once by the segments.
  return pos == index_size;
}
// Exported entry point for int64 indices/offsets, float input and float
// output. Instantiates the templated SVE kernel with
// IS_WEIGHT_POSITIONAL = false, so a non-null `weights` is indexed by the
// running index position rather than the position within each segment.
// Every argument is forwarded unchanged.
bool EmbeddingLookupIdx_int64_t_float_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositional = false;
  return EmbeddingLookupIdx_int64_t_float_float__sve<kPositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Exported entry point for int64 indices/offsets, float input and float
// output. Instantiates the templated SVE kernel with
// IS_WEIGHT_POSITIONAL = true, so a non-null `weights` is indexed by the
// position within each segment (j - offsets[i]). Every argument is forwarded
// unchanged.
bool EmbeddingLookupIdx_int64_t_float_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositional = true;
  return EmbeddingLookupIdx_int64_t_float_float__sve<kPositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_half_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])))),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])))),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])))),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])))),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])))),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])))),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])))),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])))),
            vsum15);
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[16 * vLen])))),
            vsum16);
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[17 * vLen])))),
            vsum17);
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[18 * vLen])))),
            vsum18);
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[19 * vLen])))),
            vsum19);
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[20 * vLen])))),
            vsum20);
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[21 * vLen])))),
            vsum21);
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[22 * vLen])))),
            vsum22);
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[23 * vLen])))),
            vsum23);
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[24 * vLen])))),
            vsum24);
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[25 * vLen])))),
            vsum25);
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[26 * vLen])))),
            vsum26);
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[27 * vLen])))),
            vsum27);
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[28 * vLen])))),
            vsum28);
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[29 * vLen])))),
            vsum29);
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[30 * vLen])))),
            vsum30);
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[31 * vLen])))),
            vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])))),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])))),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])))),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])))),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])))),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])))),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])))),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])))),
            vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svcvt_f32_f16_x(
                      pg,
                      svreinterpret_f16_u32(svld1uh_u32(
                          pg, reinterpret_cast<const uint16_t*>(&ip[k])))),
                  svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// Non-positional-weight entry point for fp16 (at::Half) embedding tables
// addressed by int32 indices/offsets, accumulating into fp32 `out`.
//
// Instantiates the shared kernel with IS_WEIGHT_POSITIONAL == false. The
// kernel then reads each weight as weights[pos], so there is one weight per
// entry of `indices`. A null `weights` pointer makes every weight 1.
//
// Returns the kernel's verdict:
//  - false if an index is outside [0, data_size);
//  - false if offsets[i] - offsets[0] disagrees with the number of indices
//    consumed so far;
//  - otherwise whether exactly index_size indices were consumed.
// Every argument is forwarded unchanged.
//
// NOTE: this file is generated by sve_emblookup_codegen.py. Mirror edits in
// the generator, or they will be lost on regeneration.
bool EmbeddingLookupIdx_int32_t_half_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kIsWeightPositional = false;
  return EmbeddingLookupIdx_int32_t_half_float__sve<kIsWeightPositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}
// Positional-weight entry point for fp16 (at::Half) embedding tables
// addressed by int32 indices/offsets, accumulating into fp32 `out`.
//
// Instantiates the shared kernel with IS_WEIGHT_POSITIONAL == true. The
// kernel then reads each weight as weights[j - start_offset], i.e. by
// position inside the current bag. `weights` must therefore provide at least
// as many entries as the longest bag. A null `weights` pointer makes every
// weight 1.
//
// Returns the kernel's verdict:
//  - false if an index is outside [0, data_size);
//  - false if offsets[i] - offsets[0] disagrees with the number of indices
//    consumed so far;
//  - otherwise whether exactly index_size indices were consumed.
// Every argument is forwarded unchanged.
//
// NOTE: this file is generated by sve_emblookup_codegen.py. Mirror edits in
// the generator, or they will be lost on regeneration.
bool EmbeddingLookupIdx_int32_t_half_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kIsWeightPositional = true;
  return EmbeddingLookupIdx_int32_t_half_float__sve<kIsWeightPositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_half_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])))),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])))),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])))),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])))),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])))),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])))),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])))),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])))),
            vsum15);
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[16 * vLen])))),
            vsum16);
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[17 * vLen])))),
            vsum17);
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[18 * vLen])))),
            vsum18);
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[19 * vLen])))),
            vsum19);
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[20 * vLen])))),
            vsum20);
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[21 * vLen])))),
            vsum21);
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[22 * vLen])))),
            vsum22);
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[23 * vLen])))),
            vsum23);
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[24 * vLen])))),
            vsum24);
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[25 * vLen])))),
            vsum25);
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[26 * vLen])))),
            vsum26);
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[27 * vLen])))),
            vsum27);
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[28 * vLen])))),
            vsum28);
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[29 * vLen])))),
            vsum29);
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[30 * vLen])))),
            vsum30);
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[31 * vLen])))),
            vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])))),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])))),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])))),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])))),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])))),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])))),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])))),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])))),
            vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])))),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])))),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])))),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])))),
            vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])))),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])))),
            vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])))),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_f16_x(
                svAll,
                svreinterpret_f16_u32(svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])))),
            vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::Half* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svcvt_f32_f16_x(
                      pg,
                      svreinterpret_f16_u32(svld1uh_u32(
                          pg, reinterpret_cast<const uint16_t*>(&ip[k])))),
                  svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// Exported entry point for int64 indices/offsets, at::Half embedding rows and
// float output, using NON-positional weights.
//
// It selects the IS_WEIGHT_POSITIONAL == false instantiation of the templated
// kernel. When `weights` is non-null, each lookup is therefore scaled by
// weights[pos], where pos is its running position across all of `indices`,
// not by its position within its own offsets segment.
//
// All arguments are forwarded unchanged, and the kernel's boolean result is
// returned as-is. The generic path of that kernel returns false on an
// out-of-range index or an offsets/position mismatch. Otherwise it returns
// whether exactly index_size indices were consumed.
bool EmbeddingLookupIdx_int64_t_half_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int64_t_half_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}
// Exported entry point for int64 indices/offsets, at::Half embedding rows and
// float output, using POSITIONAL weights.
//
// It selects the IS_WEIGHT_POSITIONAL == true instantiation of the templated
// kernel. When `weights` is non-null, each lookup is therefore scaled by
// weights[j - start_offset], its position inside its own offsets segment.
// The same weight vector is reused for every output row.
//
// All arguments are forwarded unchanged, and the kernel's boolean result is
// returned as-is. The generic path of that kernel returns false on an
// out-of-range index or an offsets/position mismatch. Otherwise it returns
// whether exactly index_size indices were consumed.
bool EmbeddingLookupIdx_int64_t_half_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int64_t_half_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])),
                16)),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])),
                16)),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])),
                16)),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])),
                16)),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])),
                16)),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])),
                16)),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])),
                16)),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])),
                16)),
            vsum15);
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[16 * vLen])),
                16)),
            vsum16);
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[17 * vLen])),
                16)),
            vsum17);
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[18 * vLen])),
                16)),
            vsum18);
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[19 * vLen])),
                16)),
            vsum19);
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[20 * vLen])),
                16)),
            vsum20);
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[21 * vLen])),
                16)),
            vsum21);
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[22 * vLen])),
                16)),
            vsum22);
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[23 * vLen])),
                16)),
            vsum23);
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[24 * vLen])),
                16)),
            vsum24);
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[25 * vLen])),
                16)),
            vsum25);
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[26 * vLen])),
                16)),
            vsum26);
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[27 * vLen])),
                16)),
            vsum27);
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[28 * vLen])),
                16)),
            vsum28);
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[29 * vLen])),
                16)),
            vsum29);
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[30 * vLen])),
                16)),
            vsum30);
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[31 * vLen])),
                16)),
            vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])),
                16)),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])),
                16)),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])),
                16)),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])),
                16)),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])),
                16)),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])),
                16)),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])),
                16)),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])),
                16)),
            vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svreinterpret_f32_u32(svlsl_n_u32_x(
                      pg,
                      svld1uh_u32(
                          pg, reinterpret_cast<const uint16_t*>(&ip[k])),
                      16)),
                  svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// SVE embedding-bag entry point: int32 indices/offsets, BFloat16 table rows,
// float accumulation/output, per-index (non-positional) weights.
//
// Instantiates the templated kernel with IS_WEIGHT_POSITIONAL = false. Each
// consumed index (running counter `pos`) is therefore scaled by weights[pos],
// or by 1.0f when `weights` is null. Every argument, including `scale_bias`,
// is forwarded to the kernel untouched.
//
// If `normalize_by_lengths` is set, each non-empty bag's sum is divided by
// its length before being written to out[i * block_size ...].
//
// Returns false when an index lies outside [0, data_size) or when `offsets`
// disagrees with the number of indices consumed so far. Otherwise it returns
// whether exactly `index_size` indices were consumed.
bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Weights are indexed by global index position, not by position in the bag.
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// SVE embedding-bag entry point: int32 indices/offsets, BFloat16 table rows,
// float accumulation/output, positional weights.
//
// Instantiates the templated kernel with IS_WEIGHT_POSITIONAL = true. The
// j-th index of bag i is therefore scaled by weights[j - offsets[i]], i.e. by
// its position inside its own bag, or by 1.0f when `weights` is null. Every
// argument, including `scale_bias`, is forwarded to the kernel untouched.
//
// If `normalize_by_lengths` is set, each non-empty bag's sum is divided by
// its length before being written to out[i * block_size ...].
//
// Returns false when an index lies outside [0, data_size) or when `offsets`
// disagrees with the number of indices consumed so far. Otherwise it returns
// whether exactly `index_size` indices were consumed.
bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Weights are indexed by position within each bag (reset per bag).
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])),
                16)),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])),
                16)),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])),
                16)),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])),
                16)),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])),
                16)),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])),
                16)),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])),
                16)),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])),
                16)),
            vsum15);
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[16 * vLen])),
                16)),
            vsum16);
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[17 * vLen])),
                16)),
            vsum17);
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[18 * vLen])),
                16)),
            vsum18);
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[19 * vLen])),
                16)),
            vsum19);
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[20 * vLen])),
                16)),
            vsum20);
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[21 * vLen])),
                16)),
            vsum21);
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[22 * vLen])),
                16)),
            vsum22);
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[23 * vLen])),
                16)),
            vsum23);
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[24 * vLen])),
                16)),
            vsum24);
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[25 * vLen])),
                16)),
            vsum25);
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[26 * vLen])),
                16)),
            vsum26);
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[27 * vLen])),
                16)),
            vsum27);
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[28 * vLen])),
                16)),
            vsum28);
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[29 * vLen])),
                16)),
            vsum29);
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[30 * vLen])),
                16)),
            vsum30);
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[31 * vLen])),
                16)),
            vsum31);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[8 * vLen])),
                16)),
            vsum8);
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[9 * vLen])),
                16)),
            vsum9);
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[10 * vLen])),
                16)),
            vsum10);
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[11 * vLen])),
                16)),
            vsum11);
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[12 * vLen])),
                16)),
            vsum12);
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[13 * vLen])),
                16)),
            vsum13);
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[14 * vLen])),
                16)),
            vsum14);
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[15 * vLen])),
                16)),
            vsum15);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[4 * vLen])),
                16)),
            vsum4);
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[5 * vLen])),
                16)),
            vsum5);
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[6 * vLen])),
                16)),
            vsum6);
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[7 * vLen])),
                16)),
            vsum7);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[2 * vLen])),
                16)),
            vsum2);
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[3 * vLen])),
                16)),
            vsum3);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])),
                16)),
            vsum0);
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svreinterpret_f32_u32(svlsl_n_u32_x(
                svAll,
                svld1uh_u32(
                    svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])),
                16)),
            vsum1);
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const at::BFloat16* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svreinterpret_f32_u32(svlsl_n_u32_x(
                      pg,
                      svld1uh_u32(
                          pg, reinterpret_cast<const uint16_t*>(&ip[k])),
                      16)),
                  svld1_f32(pg, &op[k])));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// Non-positional-weight entry point for int64 indices/offsets, BFloat16
// embedding rows and float output.
//
// Every argument is passed through unchanged to the templated SVE kernel
// instantiated with IS_WEIGHT_POSITIONAL = false. In that mode a non-null
// `weights` array is indexed by each lookup's absolute position in `indices`.
// The kernel's boolean result is returned as-is; false means the kernel
// rejected the input (e.g. an index outside [0, data_size) or offsets that
// disagree with the consumed index count).
bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Positional-weight entry point for int64 indices/offsets, BFloat16
// embedding rows and float output.
//
// Every argument is passed through unchanged to the templated SVE kernel
// instantiated with IS_WEIGHT_POSITIONAL = true. In that mode a non-null
// `weights` array is indexed by each lookup's offset within its own segment
// (j - start_offset), so the same weight vector is reused for every output
// row. The kernel's boolean result is returned as-is; false means the kernel
// rejected the input.
bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])),
            svadd_f32_x(svAll, vsum8, vbio));
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])),
            svadd_f32_x(svAll, vsum9, vbio));
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])),
            svadd_f32_x(svAll, vsum10, vbio));
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])),
            svadd_f32_x(svAll, vsum11, vbio));
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])),
            svadd_f32_x(svAll, vsum12, vbio));
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])),
            svadd_f32_x(svAll, vsum13, vbio));
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])),
            svadd_f32_x(svAll, vsum14, vbio));
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])),
            svadd_f32_x(svAll, vsum15, vbio));
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])),
            svadd_f32_x(svAll, vsum16, vbio));
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])),
            svadd_f32_x(svAll, vsum17, vbio));
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])),
            svadd_f32_x(svAll, vsum18, vbio));
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])),
            svadd_f32_x(svAll, vsum19, vbio));
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])),
            svadd_f32_x(svAll, vsum20, vbio));
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])),
            svadd_f32_x(svAll, vsum21, vbio));
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])),
            svadd_f32_x(svAll, vsum22, vbio));
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])),
            svadd_f32_x(svAll, vsum23, vbio));
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])),
            svadd_f32_x(svAll, vsum24, vbio));
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])),
            svadd_f32_x(svAll, vsum25, vbio));
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])),
            svadd_f32_x(svAll, vsum26, vbio));
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])),
            svadd_f32_x(svAll, vsum27, vbio));
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])),
            svadd_f32_x(svAll, vsum28, vbio));
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])),
            svadd_f32_x(svAll, vsum29, vbio));
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])),
            svadd_f32_x(svAll, vsum30, vbio));
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])),
            svadd_f32_x(svAll, vsum31, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])),
            svadd_f32_x(svAll, vsum8, vbio));
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])),
            svadd_f32_x(svAll, vsum9, vbio));
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])),
            svadd_f32_x(svAll, vsum10, vbio));
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])),
            svadd_f32_x(svAll, vsum11, vbio));
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])),
            svadd_f32_x(svAll, vsum12, vbio));
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])),
            svadd_f32_x(svAll, vsum13, vbio));
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])),
            svadd_f32_x(svAll, vsum14, vbio));
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])),
            svadd_f32_x(svAll, vsum15, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        // unimplemented
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])),
                  svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio)));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// Non-template entry point for the int32-index / uint8-data / float-output
// SVE embedding lookup with non-positional weights. When `weights` is
// non-null, the kernel reads the weight for each lookup by its flat position
// in `indices` rather than by its position inside the current segment.
// The kernel's boolean result is forwarded unchanged to the caller.
bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Selects flat-position weight indexing inside the kernel.
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Non-template entry point for the int32-index / uint8-data / float-output
// SVE embedding lookup with positional weights. When `weights` is non-null,
// the kernel reads the weight for each lookup by its offset inside the
// current segment (counted from that segment's start offset), so segments
// share the same weight vector.
// The kernel's boolean result is forwarded unchanged to the caller.
bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Selects per-segment (positional) weight indexing inside the kernel.
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  if (block_size == 32 * vLen) {
    // unrolling 32 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      svfloat32_t vsum16 = svdup_n_f32(0);
      svfloat32_t vsum17 = svdup_n_f32(0);
      svfloat32_t vsum18 = svdup_n_f32(0);
      svfloat32_t vsum19 = svdup_n_f32(0);
      svfloat32_t vsum20 = svdup_n_f32(0);
      svfloat32_t vsum21 = svdup_n_f32(0);
      svfloat32_t vsum22 = svdup_n_f32(0);
      svfloat32_t vsum23 = svdup_n_f32(0);
      svfloat32_t vsum24 = svdup_n_f32(0);
      svfloat32_t vsum25 = svdup_n_f32(0);
      svfloat32_t vsum26 = svdup_n_f32(0);
      svfloat32_t vsum27 = svdup_n_f32(0);
      svfloat32_t vsum28 = svdup_n_f32(0);
      svfloat32_t vsum29 = svdup_n_f32(0);
      svfloat32_t vsum30 = svdup_n_f32(0);
      svfloat32_t vsum31 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])),
            svadd_f32_x(svAll, vsum8, vbio));
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])),
            svadd_f32_x(svAll, vsum9, vbio));
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])),
            svadd_f32_x(svAll, vsum10, vbio));
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])),
            svadd_f32_x(svAll, vsum11, vbio));
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])),
            svadd_f32_x(svAll, vsum12, vbio));
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])),
            svadd_f32_x(svAll, vsum13, vbio));
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])),
            svadd_f32_x(svAll, vsum14, vbio));
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])),
            svadd_f32_x(svAll, vsum15, vbio));
        vsum16 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])),
            svadd_f32_x(svAll, vsum16, vbio));
        vsum17 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])),
            svadd_f32_x(svAll, vsum17, vbio));
        vsum18 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])),
            svadd_f32_x(svAll, vsum18, vbio));
        vsum19 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])),
            svadd_f32_x(svAll, vsum19, vbio));
        vsum20 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])),
            svadd_f32_x(svAll, vsum20, vbio));
        vsum21 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])),
            svadd_f32_x(svAll, vsum21, vbio));
        vsum22 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])),
            svadd_f32_x(svAll, vsum22, vbio));
        vsum23 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])),
            svadd_f32_x(svAll, vsum23, vbio));
        vsum24 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])),
            svadd_f32_x(svAll, vsum24, vbio));
        vsum25 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])),
            svadd_f32_x(svAll, vsum25, vbio));
        vsum26 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])),
            svadd_f32_x(svAll, vsum26, vbio));
        vsum27 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])),
            svadd_f32_x(svAll, vsum27, vbio));
        vsum28 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])),
            svadd_f32_x(svAll, vsum28, vbio));
        vsum29 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])),
            svadd_f32_x(svAll, vsum29, vbio));
        vsum30 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])),
            svadd_f32_x(svAll, vsum30, vbio));
        vsum31 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])),
            svadd_f32_x(svAll, vsum31, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
        svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv));
        svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv));
        svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv));
        svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv));
        svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv));
        svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv));
        svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv));
        svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv));
        svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv));
        svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv));
        svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv));
        svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv));
        svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv));
        svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv));
        svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv));
        svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
        svst1_f32(svAll, &op[16 * vLen], vsum16);
        svst1_f32(svAll, &op[17 * vLen], vsum17);
        svst1_f32(svAll, &op[18 * vLen], vsum18);
        svst1_f32(svAll, &op[19 * vLen], vsum19);
        svst1_f32(svAll, &op[20 * vLen], vsum20);
        svst1_f32(svAll, &op[21 * vLen], vsum21);
        svst1_f32(svAll, &op[22 * vLen], vsum22);
        svst1_f32(svAll, &op[23 * vLen], vsum23);
        svst1_f32(svAll, &op[24 * vLen], vsum24);
        svst1_f32(svAll, &op[25 * vLen], vsum25);
        svst1_f32(svAll, &op[26 * vLen], vsum26);
        svst1_f32(svAll, &op[27 * vLen], vsum27);
        svst1_f32(svAll, &op[28 * vLen], vsum28);
        svst1_f32(svAll, &op[29 * vLen], vsum29);
        svst1_f32(svAll, &op[30 * vLen], vsum30);
        svst1_f32(svAll, &op[31 * vLen], vsum31);
      }
    }
  } else if (block_size == 16 * vLen) {
    // unrolling 16 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      svfloat32_t vsum8 = svdup_n_f32(0);
      svfloat32_t vsum9 = svdup_n_f32(0);
      svfloat32_t vsum10 = svdup_n_f32(0);
      svfloat32_t vsum11 = svdup_n_f32(0);
      svfloat32_t vsum12 = svdup_n_f32(0);
      svfloat32_t vsum13 = svdup_n_f32(0);
      svfloat32_t vsum14 = svdup_n_f32(0);
      svfloat32_t vsum15 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        vsum8 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])),
            svadd_f32_x(svAll, vsum8, vbio));
        vsum9 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])),
            svadd_f32_x(svAll, vsum9, vbio));
        vsum10 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])),
            svadd_f32_x(svAll, vsum10, vbio));
        vsum11 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])),
            svadd_f32_x(svAll, vsum11, vbio));
        vsum12 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])),
            svadd_f32_x(svAll, vsum12, vbio));
        vsum13 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])),
            svadd_f32_x(svAll, vsum13, vbio));
        vsum14 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])),
            svadd_f32_x(svAll, vsum14, vbio));
        vsum15 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])),
            svadd_f32_x(svAll, vsum15, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
        svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv));
        svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv));
        svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv));
        svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv));
        svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv));
        svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv));
        svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv));
        svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
        svst1_f32(svAll, &op[8 * vLen], vsum8);
        svst1_f32(svAll, &op[9 * vLen], vsum9);
        svst1_f32(svAll, &op[10 * vLen], vsum10);
        svst1_f32(svAll, &op[11 * vLen], vsum11);
        svst1_f32(svAll, &op[12 * vLen], vsum12);
        svst1_f32(svAll, &op[13 * vLen], vsum13);
        svst1_f32(svAll, &op[14 * vLen], vsum14);
        svst1_f32(svAll, &op[15 * vLen], vsum15);
      }
    }
  } else if (block_size == 8 * vLen) {
    // unrolling 8 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      svfloat32_t vsum4 = svdup_n_f32(0);
      svfloat32_t vsum5 = svdup_n_f32(0);
      svfloat32_t vsum6 = svdup_n_f32(0);
      svfloat32_t vsum7 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        vsum4 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])),
            svadd_f32_x(svAll, vsum4, vbio));
        vsum5 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])),
            svadd_f32_x(svAll, vsum5, vbio));
        vsum6 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])),
            svadd_f32_x(svAll, vsum6, vbio));
        vsum7 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])),
            svadd_f32_x(svAll, vsum7, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
        svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv));
        svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv));
        svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv));
        svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
        svst1_f32(svAll, &op[4 * vLen], vsum4);
        svst1_f32(svAll, &op[5 * vLen], vsum5);
        svst1_f32(svAll, &op[6 * vLen], vsum6);
        svst1_f32(svAll, &op[7 * vLen], vsum7);
      }
    }
  } else if (block_size == 4 * vLen) {
    // unrolling 4 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      svfloat32_t vsum2 = svdup_n_f32(0);
      svfloat32_t vsum3 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        vsum2 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])),
            svadd_f32_x(svAll, vsum2, vbio));
        vsum3 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])),
            svadd_f32_x(svAll, vsum3, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
        svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv));
        svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
        svst1_f32(svAll, &op[2 * vLen], vsum2);
        svst1_f32(svAll, &op[3 * vLen], vsum3);
      }
    }
  } else if (block_size == 2 * vLen) {
    // unrolling 2 times
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      svfloat32_t vsum0 = svdup_n_f32(0);
      svfloat32_t vsum1 = svdup_n_f32(0);
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* const ip = &input[idx * block_size];
        // weight * input + out
        vsum0 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])),
            svadd_f32_x(svAll, vsum0, vbio));
        vsum1 = svmad_f32_x(
            svAll,
            vwgt,
            svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])),
            svadd_f32_x(svAll, vsum1, vbio));
        ++pos;
      }
      // Normalisation
      const int64_t length = end_offset - start_offset;
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv));
        svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv));
      } else {
        svst1_f32(svAll, &op[0 * vLen], vsum0);
        svst1_f32(svAll, &op[1 * vLen], vsum1);
      }
    }
  } else {
    // generic code:
    for (int64_t i = 0; i < output_size; ++i) {
      float* const op = &out[i * block_size];
      memset(op, 0, sizeof(float) * block_size);
      if (pos != offsets[i] - offsets[0]) {
        return false;
      }
      int64_t start_offset = offsets[i];
      int64_t end_offset = offsets[i + 1];
      for (int64_t j = start_offset; j < end_offset; ++j) {
        const auto idx = indices[pos];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        // unimplemented
        float wgt = 1.f;
        float bio{};
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];
        }
        if (scale_bias) {
          bio = wgt * scale_bias[2 * idx + 1];
          wgt = wgt * scale_bias[2 * idx];
        }
        svfloat32_t vbio = svdup_n_f32(bio);
        const svfloat32_t vwgt = svdup_n_f32(wgt);
        const uint8_t* ip = &input[idx * block_size];
        svbool_t pg;
        for (int64_t k = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size));
             k += vLen) {
          svst1_f32(
              pg,
              &op[k],
              svmad_f32_x(
                  pg,
                  vwgt,
                  svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])),
                  svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio)));
        }

        ++pos;
      }
      const int64_t length = end_offset - start_offset;

      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        svfloat32_t vlen_inv = svdup_n_f32(len_inv);
        svbool_t pg;
        for (int64_t j = 0;
             svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size));
             j += vLen) {
          svst1_f32(
              pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));
        }
      }
    }
  }
  return pos == index_size;
}
// SVE embedding lookup for int64_t indices/offsets, uint8_t input rows and
// float output, with NON-positional weights.
//
// Every argument is forwarded unchanged to
// EmbeddingLookupIdx_int64_t_uint8_t_float__sve<false>. In that
// instantiation a non-null `weights` array is read as weights[pos], where
// `pos` is the running position in `indices`. That means one weight per
// looked-up index.
//
// Returns the template's status:
//   - false if a segment start (offsets[i] - offsets[0]) disagrees with the
//     number of indices consumed so far;
//   - false if an index lies outside [0, data_size);
//   - otherwise, true iff exactly `index_size` indices were consumed.
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // IS_WEIGHT_POSITIONAL = false -> weights indexed by global position.
  const bool succeeded = EmbeddingLookupIdx_int64_t_uint8_t_float__sve<false>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias,
      normalize_by_lengths, out);
  return succeeded;
}
// SVE embedding lookup for int64_t indices/offsets, uint8_t input rows and
// float output, with POSITIONAL weights.
//
// Every argument is forwarded unchanged to
// EmbeddingLookupIdx_int64_t_uint8_t_float__sve<true>. In that
// instantiation a non-null `weights` array is read as
// weights[j - offsets[i]], i.e. by the position of the index within its own
// segment [offsets[i], offsets[i + 1]). Every segment therefore reuses the
// same leading weights.
//
// Returns the template's status:
//   - false if a segment start (offsets[i] - offsets[0]) disagrees with the
//     number of indices consumed so far;
//   - false if an index lies outside [0, data_size);
//   - otherwise, true iff exactly `index_size` indices were consumed.
bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // IS_WEIGHT_POSITIONAL = true -> weights indexed by in-segment position.
  const bool succeeded = EmbeddingLookupIdx_int64_t_uint8_t_float__sve<true>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias,
      normalize_by_lengths, out);
  return succeeded;
}

} // namespace caffe2
