//// --------------------------
//// ATTENTION:
//// THIS CODE IS AUTOGENERATED
//// BY sve_emblookup_codegen.py
//// DO NOT MODIFY!!!
//// --------------------------

#include <arm_sve.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cstdint>
#include <cstring>
namespace caffe2 {

// Pools float embedding rows into one output row per bag, using SVE for the
// per-row accumulation: out[i] = sum over bag i of wgt * input[idx].
//
// Bag i covers indices[offsets[i] - offsets[0] .. offsets[i + 1] - offsets[0]),
// so `offsets` must hold at least output_size + 1 entries. Every index must be
// in [0, data_size); each row of `input` and `out` is block_size floats.
//
// Weights (weights == nullptr means every weight is 1.0f):
//   IS_WEIGHT_POSITIONAL == false: weights[p] scales the row of indices[p].
//   IS_WEIGHT_POSITIONAL == true:  weights[q] scales the q-th entry of each bag.
// When normalize_by_lengths is set, every non-empty bag is scaled by 1/length.
// scale_bias is not used by this float-input kernel.
//
// Returns false if offsets disagree with the number of indices consumed so
// far, if an index is out of range, or if the total number of indices consumed
// differs from index_size. On a false return `out` may be partially written.
//
// NOTE(review): the file header says this code is emitted by
// sve_emblookup_codegen.py; the tail-load predicate fix below must also be
// applied to the generator or it will be lost on regeneration.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_float_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    // The indices consumed so far must line up with this bag's start offset.
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      const float* const ip4 = &input[idx4 * block_size];
      const float* const ip5 = &input[idx5 * block_size];
      const float* const ip6 = &input[idx6 * block_size];
      const float* const ip7 = &input[idx7 * block_size];
      const float* const ip8 = &input[idx8 * block_size];
      const float* const ip9 = &input[idx9 * block_size];
      const float* const ip10 = &input[idx10 * block_size];
      const float* const ip11 = &input[idx11 * block_size];
      const float* const ip12 = &input[idx12 * block_size];
      const float* const ip13 = &input[idx13 * block_size];
      const float* const ip14 = &input[idx14 * block_size];
      const float* const ip15 = &input[idx15 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
        output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
        output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
        output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
        output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
        output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
        output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
        output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
        output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
        output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
        output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
        output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: only lanes [k, block_size) are active. The input
        // loads must be predicated with `pg` too; an all-true load would read
        // past the end of the row, and past the end of `input` for its last
        // row.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        output = svmla_x(pg, output, svld1(pg, &ip4[k]), wgt4);
        output = svmla_x(pg, output, svld1(pg, &ip5[k]), wgt5);
        output = svmla_x(pg, output, svld1(pg, &ip6[k]), wgt6);
        output = svmla_x(pg, output, svld1(pg, &ip7[k]), wgt7);
        output = svmla_x(pg, output, svld1(pg, &ip8[k]), wgt8);
        output = svmla_x(pg, output, svld1(pg, &ip9[k]), wgt9);
        output = svmla_x(pg, output, svld1(pg, &ip10[k]), wgt10);
        output = svmla_x(pg, output, svld1(pg, &ip11[k]), wgt11);
        output = svmla_x(pg, output, svld1(pg, &ip12[k]), wgt12);
        output = svmla_x(pg, output, svld1(pg, &ip13[k]), wgt13);
        output = svmla_x(pg, output, svld1(pg, &ip14[k]), wgt14);
        output = svmla_x(pg, output, svld1(pg, &ip15[k]), wgt15);
        svst1(pg, &op[k], output);
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      const float* const ip4 = &input[idx4 * block_size];
      const float* const ip5 = &input[idx5 * block_size];
      const float* const ip6 = &input[idx6 * block_size];
      const float* const ip7 = &input[idx7 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
        output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
        output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
        output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        output = svmla_x(pg, output, svld1(pg, &ip4[k]), wgt4);
        output = svmla_x(pg, output, svld1(pg, &ip5[k]), wgt5);
        output = svmla_x(pg, output, svld1(pg, &ip6[k]), wgt6);
        output = svmla_x(pg, output, svld1(pg, &ip7[k]), wgt7);
        svst1(pg, &op[k], output);
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        svst1(pg, &op[k], output);
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        svst1(pg, &op[k], output);
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const float* const ip0 = &input[idx0 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input load as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        svst1(pg, &op[k], output);
      }
      ++pos;
    }
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        svst1(svAll, &op[k], svmul_x(svAll, svld1(svAll, &op[k]), len_inv));
        k += vLen;
      }
      if (k < block_size) {
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        svst1(pg, &op[k], svmul_x(pg, svld1(pg, &op[k]), len_inv));
      }
    }
  }
  // Every index must have been consumed exactly once.
  return pos == index_size;
}
// Entry point for int32 indices/offsets, float embedding rows and
// per-index (non-positional) weights. Every argument is forwarded unchanged
// to the SVE template instantiated with IS_WEIGHT_POSITIONAL = false, and its
// result is returned as-is.
bool EmbeddingLookupIdx_int32_t_float_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = false;
  const bool ok =
      EmbeddingLookupIdx_int32_t_float_float__sve<kWeightsArePositional>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}
// Entry point for int32 indices/offsets, float embedding rows and positional
// weights (weights indexed by position within each bag). Every argument is
// forwarded unchanged to the SVE template instantiated with
// IS_WEIGHT_POSITIONAL = true, and its result is returned as-is.
bool EmbeddingLookupIdx_int32_t_float_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = true;
  const bool ok =
      EmbeddingLookupIdx_int32_t_float_float__sve<kWeightsArePositional>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}

// Pools float embedding rows into one output row per bag, using SVE for the
// per-row accumulation: out[i] = sum over bag i of wgt * input[idx].
// Identical to the int32 variant except that indices/offsets are int64_t.
//
// Bag i covers indices[offsets[i] - offsets[0] .. offsets[i + 1] - offsets[0]),
// so `offsets` must hold at least output_size + 1 entries. Every index must be
// in [0, data_size); each row of `input` and `out` is block_size floats.
//
// Weights (weights == nullptr means every weight is 1.0f):
//   IS_WEIGHT_POSITIONAL == false: weights[p] scales the row of indices[p].
//   IS_WEIGHT_POSITIONAL == true:  weights[q] scales the q-th entry of each bag.
// When normalize_by_lengths is set, every non-empty bag is scaled by 1/length.
// scale_bias is not used by this float-input kernel.
//
// Returns false if offsets disagree with the number of indices consumed so
// far, if an index is out of range, or if the total number of indices consumed
// differs from index_size. On a false return `out` may be partially written.
//
// NOTE(review): the file header says this code is emitted by
// sve_emblookup_codegen.py; the tail-load predicate fix below must also be
// applied to the generator or it will be lost on regeneration.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_float_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    // The indices consumed so far must line up with this bag's start offset.
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      const float* const ip4 = &input[idx4 * block_size];
      const float* const ip5 = &input[idx5 * block_size];
      const float* const ip6 = &input[idx6 * block_size];
      const float* const ip7 = &input[idx7 * block_size];
      const float* const ip8 = &input[idx8 * block_size];
      const float* const ip9 = &input[idx9 * block_size];
      const float* const ip10 = &input[idx10 * block_size];
      const float* const ip11 = &input[idx11 * block_size];
      const float* const ip12 = &input[idx12 * block_size];
      const float* const ip13 = &input[idx13 * block_size];
      const float* const ip14 = &input[idx14 * block_size];
      const float* const ip15 = &input[idx15 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
        output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
        output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
        output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
        output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
        output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
        output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
        output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
        output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
        output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
        output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
        output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: only lanes [k, block_size) are active. The input
        // loads must be predicated with `pg` too; an all-true load would read
        // past the end of the row, and past the end of `input` for its last
        // row.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        output = svmla_x(pg, output, svld1(pg, &ip4[k]), wgt4);
        output = svmla_x(pg, output, svld1(pg, &ip5[k]), wgt5);
        output = svmla_x(pg, output, svld1(pg, &ip6[k]), wgt6);
        output = svmla_x(pg, output, svld1(pg, &ip7[k]), wgt7);
        output = svmla_x(pg, output, svld1(pg, &ip8[k]), wgt8);
        output = svmla_x(pg, output, svld1(pg, &ip9[k]), wgt9);
        output = svmla_x(pg, output, svld1(pg, &ip10[k]), wgt10);
        output = svmla_x(pg, output, svld1(pg, &ip11[k]), wgt11);
        output = svmla_x(pg, output, svld1(pg, &ip12[k]), wgt12);
        output = svmla_x(pg, output, svld1(pg, &ip13[k]), wgt13);
        output = svmla_x(pg, output, svld1(pg, &ip14[k]), wgt14);
        output = svmla_x(pg, output, svld1(pg, &ip15[k]), wgt15);
        svst1(pg, &op[k], output);
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      const float* const ip4 = &input[idx4 * block_size];
      const float* const ip5 = &input[idx5 * block_size];
      const float* const ip6 = &input[idx6 * block_size];
      const float* const ip7 = &input[idx7 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
        output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
        output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
        output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        output = svmla_x(pg, output, svld1(pg, &ip4[k]), wgt4);
        output = svmla_x(pg, output, svld1(pg, &ip5[k]), wgt5);
        output = svmla_x(pg, output, svld1(pg, &ip6[k]), wgt6);
        output = svmla_x(pg, output, svld1(pg, &ip7[k]), wgt7);
        svst1(pg, &op[k], output);
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      const float* const ip2 = &input[idx2 * block_size];
      const float* const ip3 = &input[idx3 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
        output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        output = svmla_x(pg, output, svld1(pg, &ip2[k]), wgt2);
        output = svmla_x(pg, output, svld1(pg, &ip3[k]), wgt3);
        svst1(pg, &op[k], output);
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const float* const ip0 = &input[idx0 * block_size];
      const float* const ip1 = &input[idx1 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input loads as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        output = svmla_x(pg, output, svld1(pg, &ip1[k]), wgt1);
        svst1(pg, &op[k], output);
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const float* const ip0 = &input[idx0 * block_size];
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        // Partial vector: predicate the input load as well as op.
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svmla_x(pg, output, svld1(pg, &ip0[k]), wgt0);
        svst1(pg, &op[k], output);
      }
      ++pos;
    }
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        svst1(svAll, &op[k], svmul_x(svAll, svld1(svAll, &op[k]), len_inv));
        k += vLen;
      }
      if (k < block_size) {
        const svbool_t pg = svwhilelt_b32_s64(k, block_size);
        svst1(pg, &op[k], svmul_x(pg, svld1(pg, &op[k]), len_inv));
      }
    }
  }
  // Every index must have been consumed exactly once.
  return pos == index_size;
}
// Non-positional-weights entry point for int64 indices/offsets over a float
// table. Every argument is passed through unchanged to the shared SVE kernel
// instantiated with IS_WEIGHT_POSITIONAL = false, so a non-null `weights` is
// indexed by the running position into `indices` rather than by the position
// inside each segment.
bool EmbeddingLookupIdx_int64_t_float_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = false;
  const bool ok =
      EmbeddingLookupIdx_int64_t_float_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}
// Positional-weights entry point for int64 indices/offsets over a float
// table. Every argument is passed through unchanged to the shared SVE kernel
// instantiated with IS_WEIGHT_POSITIONAL = true, so a non-null `weights` is
// indexed by the position inside the current segment (restarting at 0 for
// each output row).
bool EmbeddingLookupIdx_int64_t_float_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = true;
  const bool ok =
      EmbeddingLookupIdx_int64_t_float_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}

// SVE kernel: segment-wise (optionally weighted) sum of fp16 table rows,
// accumulated and written out in fp32, for int32 indices/offsets.
//
// For each output row i in [0, output_size):
//   * out[i * block_size, (i + 1) * block_size) is zeroed, then for every
//     index idx in indices[offsets[i] - offsets[0], offsets[i + 1] - offsets[0])
//     the table row input[idx * block_size, (idx + 1) * block_size) is added
//     to it, multiplied by its weight (1.0f when `weights` is null).
//   * If normalize_by_lengths is set and the segment is non-empty, the row is
//     then multiplied by 1.0f / (offsets[i + 1] - offsets[i]).
//
// Weight indexing: with IS_WEIGHT_POSITIONAL the weight is
// weights[j - start_offset] (position inside the current segment); otherwise
// it is weights[pos] (running position into indices[]).
//
// Parameters:
//   block_size           elements per table row and per output row.
//   output_size          number of segments / output rows. offsets[] is read
//                        at 0..output_size, so it needs output_size + 1 entries.
//   index_size           total number of indices the caller expects consumed.
//   data_size            number of rows in `input`; every index must lie in
//                        [0, data_size).
//   input                fp16 table, row-major with block_size columns.
//   scale_bias           not referenced by this kernel.
//   normalize_by_lengths divide each row by its segment length.
//   out                  output_size x block_size fp32 result.
//
// Returns false as soon as the running position disagrees with
// offsets[i] - offsets[0], or an index is outside [0, data_size). In that
// case `out` is left partially written. Otherwise returns
// pos == index_size, i.e. whether exactly index_size indices were consumed.
//
// NOTE(review): indices[pos + n] is read before pos is compared with
// index_size (that comparison happens only on return). Callers presumably
// guarantee offsets[output_size] - offsets[0] <= index_size; confirm at call
// sites.
// NOTE(review): the file header says this is generated by
// sve_emblookup_codegen.py; behavioural changes belong in the generator.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_half_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  // Number of 32-bit lanes per SVE vector (hardware-dependent).
  const auto vLen = static_cast<int64_t>(svcntw());
  // Running position into indices[] (and into weights[] when not positional).
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    // The output row is an accumulator: start it at zero.
    memset(op, 0, sizeof(float) * block_size);
    // Segments must be laid out back-to-back in indices[].
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    // j walks the segment in offsets[] coordinates; pos advances in lockstep.
    int64_t j = start_offset;
    // unrolling 16 times
    // Each unrolled group (16/8/4/2/1) reads and range-checks all of its
    // indices before any of them is accumulated into `op`.
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      const at::Half* const ip4 = &input[idx4 * block_size];
      const at::Half* const ip5 = &input[idx5 * block_size];
      const at::Half* const ip6 = &input[idx6 * block_size];
      const at::Half* const ip7 = &input[idx7 * block_size];
      const at::Half* const ip8 = &input[idx8 * block_size];
      const at::Half* const ip9 = &input[idx9 * block_size];
      const at::Half* const ip10 = &input[idx10 * block_size];
      const at::Half* const ip11 = &input[idx11 * block_size];
      const at::Half* const ip12 = &input[idx12 * block_size];
      const at::Half* const ip13 = &input[idx13 * block_size];
      const at::Half* const ip14 = &input[idx14 * block_size];
      const at::Half* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      // Full vectors: vLen columns per iteration with every lane active.
      // svld1uh_u32 zero-extends each 16-bit fp16 pattern into a 32-bit lane;
      // after reinterpreting, svcvt_f32_x converts the fp16 value held in the
      // low half of each lane to fp32. svmla_x then accumulates
      // output += input * weight.
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
        auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
        auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
        auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
        auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
        auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
        auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
        auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      // Remainder (fewer than vLen columns): the svwhilelt predicate enables
      // only lanes for columns k..block_size-1, so loads and stores stay
      // inside the row. The trailing `k += vLen` has no effect; k is not read
      // again.
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
        auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
        auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
        auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
        auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
        auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
        auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
        auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    // Same pattern as the 16-wide group; runs at most once after it.
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      const at::Half* const ip4 = &input[idx4 * block_size];
      const at::Half* const ip5 = &input[idx5 * block_size];
      const at::Half* const ip6 = &input[idx6 * block_size];
      const at::Half* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    // The 2x loop exits once j + 1 >= end_offset, so at most one index is
    // left; this is a single step, not a loop. Only pos is advanced because
    // j is not read again for this segment.
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    // Optional mean: scale the row by 1 / segment length (empty segments stay
    // zero). The `j` declared here shadows the segment cursor above and is
    // only a column index into `op`.
    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  // Succeed only if the segments consumed exactly index_size indices.
  return pos == index_size;
}
// Non-positional-weights entry point for int32 indices/offsets over an fp16
// (at::Half) table with fp32 output. Every argument is passed through
// unchanged to the shared SVE kernel instantiated with
// IS_WEIGHT_POSITIONAL = false, so a non-null `weights` is indexed by the
// running position into `indices`.
bool EmbeddingLookupIdx_int32_t_half_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = false;
  const bool ok =
      EmbeddingLookupIdx_int32_t_half_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}
// Positional-weights entry point for int32 indices/offsets over an fp16
// (at::Half) table with fp32 output. Every argument is passed through
// unchanged to the shared SVE kernel instantiated with
// IS_WEIGHT_POSITIONAL = true, so a non-null `weights` is indexed by the
// position inside the current segment.
bool EmbeddingLookupIdx_int32_t_half_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = true;
  const bool ok =
      EmbeddingLookupIdx_int32_t_half_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return ok;
}

// Weighted sum-pooling embedding lookup (SVE): int64 indices/offsets, at::Half
// embedding table, float accumulation and output.
//
// For each bag i in [0, output_size):
//   - the output row out[i * block_size, (i + 1) * block_size) is zeroed;
//   - for every offset j in [offsets[i], offsets[i + 1]) the table row
//     input[indices[pos] * block_size, ...) is widened to float, multiplied by
//     a weight (1.f when `weights` is null) and added into the output row;
//   - if normalize_by_lengths is set and offsets[i + 1] - offsets[i] != 0, the
//     output row is multiplied by 1.0f / (offsets[i + 1] - offsets[i]).
//
// IS_WEIGHT_POSITIONAL selects how `weights` is indexed:
//   true  -> weights[j - start_offset]  (position inside the current bag)
//   false -> weights[pos]               (running position in `indices`)
//
// Parameters:
//   block_size           elements per table row and per output row.
//   output_size          number of bags; offsets[0 .. output_size] is read,
//                        i.e. output_size + 1 entries.
//   index_size           total number of indices the caller expects to be
//                        consumed (compared only in the final return).
//   data_size            number of table rows; every index must lie in
//                        [0, data_size).
//   input                table rows laid out contiguously, block_size each.
//   indices              bag i uses indices[offsets[i] - offsets[0] ...]
//                        for offsets[i + 1] - offsets[i] entries.
//   offsets              bag boundaries (see above).
//   weights              optional per-index weights; may be null.
//   scale_bias           not read by this function.
//   normalize_by_lengths divide each non-empty bag by its length.
//   out                  output_size * block_size floats.
//
// Returns false when, at the start of a bag, the number of indices consumed
// so far differs from offsets[i] - offsets[0], or when any index falls outside
// [0, data_size); `out` may then be partially written. Otherwise returns
// whether exactly index_size indices were consumed.
//
// NOTE(review): indices[pos + k] (and weights[pos + k]) are read before any
// comparison with index_size, so offsets extending past the end of `indices`
// cause out-of-bounds reads rather than a `false` return. This file is
// generated by sve_emblookup_codegen.py; any hardening belongs there.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_half_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // All-true predicate for full vectors; vLen is the number of 32-bit lanes.
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  // Running position in `indices` (and in `weights` when not positional).
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    // Start every bag from a zeroed output row.
    memset(op, 0, sizeof(float) * block_size);
    // The indices consumed so far must match where this bag's offset says
    // it starts.
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    // Each unrolled group validates all of its indices before accumulating
    // any of them into `op`.
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      // Unweighted lookup uses weight 1.f for every index.
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        // Positional: offset within the bag; otherwise: global position.
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      const at::Half* const ip4 = &input[idx4 * block_size];
      const at::Half* const ip5 = &input[idx5 * block_size];
      const at::Half* const ip6 = &input[idx6 * block_size];
      const at::Half* const ip7 = &input[idx7 * block_size];
      const at::Half* const ip8 = &input[idx8 * block_size];
      const at::Half* const ip9 = &input[idx9 * block_size];
      const at::Half* const ip10 = &input[idx10 * block_size];
      const at::Half* const ip11 = &input[idx11 * block_size];
      const at::Half* const ip12 = &input[idx12 * block_size];
      const at::Half* const ip13 = &input[idx13 * block_size];
      const at::Half* const ip14 = &input[idx14 * block_size];
      const at::Half* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      // Full vectors first; a single predicated pass below handles the last
      // block_size % vLen elements.
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        // svld1uh_u32 loads uint16 values into 32-bit lanes (zero-extended per
        // ACLE); the vector is reinterpreted as f16 and widened to f32.
        // NOTE(review): this relies on the f16->f32 convert reading the low
        // 16 bits of each 32-bit lane (ACLE/FCVT semantics) — confirm when
        // porting.
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
        auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
        auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
        auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
        auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
        auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
        auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
        auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
        // Accumulate in index order: output += input_n * wgt_n.
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      // Remainder: pg is active only for lanes k .. block_size - 1.
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
        auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
        auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
        auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
        auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
        auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
        auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
        auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      const at::Half* const ip4 = &input[idx4 * block_size];
      const at::Half* const ip5 = &input[idx5 * block_size];
      const at::Half* const ip6 = &input[idx6 * block_size];
      const at::Half* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
        auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
        auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
        auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      const at::Half* const ip2 = &input[idx2 * block_size];
      const at::Half* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
        auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      const at::Half* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    // At most one index remains here. j is not advanced afterwards because
    // this is the last index of the bag.
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const at::Half* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    // Optional mean pooling: scale the finished row by 1 / length (skipped
    // for empty bags to avoid dividing by zero).
    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  // Success only if every index in `indices` was consumed exactly once.
  return pos == index_size;
}
// Non-template entry point: int64 indices/offsets, at::Half table, float
// output. Uses IS_WEIGHT_POSITIONAL = false, so non-null weights are indexed by
// the running position in `indices`. All arguments are forwarded unchanged and
// the kernel's result is returned.
bool EmbeddingLookupIdx_int64_t_half_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = false;
  return EmbeddingLookupIdx_int64_t_half_float__sve<kPositionalWeights>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Non-template entry point: int64 indices/offsets, at::Half table, float
// output. Uses IS_WEIGHT_POSITIONAL = true, so non-null weights are indexed by
// position within each bag. All arguments are forwarded unchanged and the
// kernel's result is returned.
bool EmbeddingLookupIdx_int64_t_half_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = true;
  return EmbeddingLookupIdx_int64_t_half_float__sve<kPositionalWeights>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

// SVE kernel for the offsets-based embedding lookup with int32 indices and
// offsets, an at::BFloat16 embedding table and float32 output.
// (Generated by sve_emblookup_codegen.py; edit the generator, not this code.)
//
// For every output row (bag) i in [0, output_size):
//   * out[i * block_size, (i + 1) * block_size) is zeroed;
//   * for each entry e in [offsets[i], offsets[i + 1]) -- stored in `indices`
//     at position e - offsets[0] -- the table row starting at
//     input[indices[...] * block_size] is widened to float, multiplied by its
//     weight w_e (1.0f when `weights` is null) and accumulated into the row;
//   * if normalize_by_lengths is set and the bag is non-empty, the row is
//     multiplied by 1.0f / (offsets[i + 1] - offsets[i]).
//
// Template parameter:
//   IS_WEIGHT_POSITIONAL - true:  w_e = weights[e - offsets[i]] (position
//                                 within the bag);
//                          false: w_e = weights[e - offsets[0]] (position in
//                                 `indices`).
//
// Parameters:
//   block_size  - elements per table row and per output row.
//   output_size - number of bags; `offsets` is read at [0, output_size] and
//                 `out` must hold output_size * block_size floats.
//   index_size  - total number of entries expected to be consumed from
//                 `indices`.
//   data_size   - exclusive upper bound for indices (presumably the number of
//                 rows in `input`).
//   scale_bias  - not referenced by this kernel.
//
// Returns false if the entries consumed so far disagree with
// offsets[i] - offsets[0] at the start of a bag, if any index lies outside
// [0, data_size), or if the total consumed differs from index_size. Otherwise
// it returns true. On a false return `out` may already be partially written.
//
// NOTE(review): `indices` (and `weights`) are read before the only check
// against index_size (the final return). Offsets that overshoot index_size
// therefore cause reads past the end of `indices` rather than an early false.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // All-true predicate for 32-bit lanes; vLen = number of float lanes per
  // SVE vector on the running hardware.
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  // pos: position in `indices` (and non-positional `weights`) of the next
  // entry to consume.
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    // Bag i must begin exactly where the consumed entries end.
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    // j: current entry in offset space; advances in lockstep with pos.
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      // Reject out-of-range rows before any table access.
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      // Per-entry weights default to 1.0f. When `weights` is given they are
      // indexed within the bag (positional) or by position in `indices`.
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      // First element of each selected BFloat16 table row.
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      const at::BFloat16* const ip4 = &input[idx4 * block_size];
      const at::BFloat16* const ip5 = &input[idx5 * block_size];
      const at::BFloat16* const ip6 = &input[idx6 * block_size];
      const at::BFloat16* const ip7 = &input[idx7 * block_size];
      const at::BFloat16* const ip8 = &input[idx8 * block_size];
      const at::BFloat16* const ip9 = &input[idx9 * block_size];
      const at::BFloat16* const ip10 = &input[idx10 * block_size];
      const at::BFloat16* const ip11 = &input[idx11 * block_size];
      const at::BFloat16* const ip12 = &input[idx12 * block_size];
      const at::BFloat16* const ip13 = &input[idx13 * block_size];
      const at::BFloat16* const ip14 = &input[idx14 * block_size];
      const at::BFloat16* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      // Whole vectors: op[k..k+vLen) += wgtN * float(ipN[k..k+vLen)) for each
      // N, accumulated sequentially. Each 16-bit bf16 value is zero-extended
      // into a 32-bit lane (svld1uh_u32), shifted into the upper half
      // (svlsl_x by 16) and reinterpreted as float.
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        auto input8 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
        auto input9 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
        auto input10 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
        auto input11 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
        auto input12 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
        auto input13 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
        auto input14 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
        auto input15 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      // Remainder of fewer than vLen elements. The same computation runs
      // under a svwhilelt predicate, so no lane at or past block_size is
      // loaded or stored. The trailing `k += vLen` has no further effect.
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        auto input8 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
        auto input9 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
        auto input10 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
        auto input11 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
        auto input12 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
        auto input13 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
        auto input14 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
        auto input15 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    // (same validate / weight / widen-and-accumulate scheme, 8 entries)
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      const at::BFloat16* const ip4 = &input[idx4 * block_size];
      const at::BFloat16* const ip5 = &input[idx5 * block_size];
      const at::BFloat16* const ip6 = &input[idx6 * block_size];
      const at::BFloat16* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    // (same scheme, 4 entries)
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    // (same scheme, 2 entries)
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    // (at most one remaining entry; processed once, so j is not advanced)
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    // Optionally scale the accumulated row by 1 / bag length; empty bags are
    // left as zeros. (This `j` is a fresh element counter that shadows the
    // entry counter above.)
    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  // Success only if exactly index_size entries were consumed in total.
  return pos == index_size;
}
// Exported entry point: int32 indices/offsets, at::BFloat16 embedding table,
// float output. Thin forwarder that instantiates the shared SVE kernel
// template with IS_WEIGHT_POSITIONAL = false (`weights`, when present, are
// indexed by each entry's position in `indices`). It hands back the kernel's
// success flag unchanged.
bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = false;
  const bool succeeded =
      EmbeddingLookupIdx_int32_t_bfloat16_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return succeeded;
}
// Exported entry point: int32 indices/offsets, at::BFloat16 embedding table,
// float output. Thin forwarder that instantiates the shared SVE kernel
// template with IS_WEIGHT_POSITIONAL = true (`weights`, when present, are
// indexed by each entry's position within its bag). It hands back the
// kernel's success flag unchanged.
bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kPositionalWeights = true;
  const bool succeeded =
      EmbeddingLookupIdx_int32_t_bfloat16_float__sve<kPositionalWeights>(
          block_size, output_size, index_size, data_size, input, indices,
          offsets, weights, scale_bias, normalize_by_lengths, out);
  return succeeded;
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      const at::BFloat16* const ip4 = &input[idx4 * block_size];
      const at::BFloat16* const ip5 = &input[idx5 * block_size];
      const at::BFloat16* const ip6 = &input[idx6 * block_size];
      const at::BFloat16* const ip7 = &input[idx7 * block_size];
      const at::BFloat16* const ip8 = &input[idx8 * block_size];
      const at::BFloat16* const ip9 = &input[idx9 * block_size];
      const at::BFloat16* const ip10 = &input[idx10 * block_size];
      const at::BFloat16* const ip11 = &input[idx11 * block_size];
      const at::BFloat16* const ip12 = &input[idx12 * block_size];
      const at::BFloat16* const ip13 = &input[idx13 * block_size];
      const at::BFloat16* const ip14 = &input[idx14 * block_size];
      const at::BFloat16* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        auto input8 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
        auto input9 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
        auto input10 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
        auto input11 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
        auto input12 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
        auto input13 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
        auto input14 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
        auto input15 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        auto input8 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
        auto input9 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
        auto input10 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
        auto input11 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
        auto input12 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
        auto input13 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
        auto input14 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
        auto input15 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      const at::BFloat16* const ip4 = &input[idx4 * block_size];
      const at::BFloat16* const ip5 = &input[idx5 * block_size];
      const at::BFloat16* const ip6 = &input[idx6 * block_size];
      const at::BFloat16* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        auto input4 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
        auto input5 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
        auto input6 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
        auto input7 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      const at::BFloat16* const ip2 = &input[idx2 * block_size];
      const at::BFloat16* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        auto input2 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
        auto input3 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      const at::BFloat16* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        auto input1 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      const at::BFloat16* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(svAll,
          svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        auto input0 = svreinterpret_f32(svlsl_x(pg,
          svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
// Public entry point: BFloat16 embedding table, int64 indices/offsets, float
// output. Forwards every argument unchanged to the IS_WEIGHT_POSITIONAL = false
// instantiation of the kernel. In that mode, when `weights` is provided, each
// lookup's weight is taken from its global position in `indices`. The kernel's
// success flag is returned as-is.
bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}
// Public entry point: BFloat16 embedding table, int64 indices/offsets, float
// output. Forwards every argument unchanged to the IS_WEIGHT_POSITIONAL = true
// instantiation of the kernel. In that mode, when `weights` is provided, each
// lookup's weight is taken from its position inside its own segment. The
// kernel's success flag is returned as-is.
bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets, weights, scale_bias,
      normalize_by_lengths, out);
}

// Embedding-bag lookup over a uint8 table with int32 indices/offsets (SVE).
//
// For each output bag i in [0, output_size):
//   * out[i * block_size, (i + 1) * block_size) is zeroed, then receives the
//     sum over the bag's entries of  w * (scale * row + bias), where row is
//     input[idx * block_size ...] converted to float.
//   * The bag's entries are indices[offsets[i] - offsets[0], offsets[i + 1] -
//     offsets[0]), so offsets is read at positions 0..output_size.
//   * w is 1 when weights is null; otherwise weights[j - offsets[i]] (position
//     within the bag) if IS_WEIGHT_POSITIONAL, else weights[pos] (global
//     position in indices).
//   * scale/bias are scale_bias[2 * idx] / scale_bias[2 * idx + 1] when
//     scale_bias is non-null; otherwise the row is used as-is (scale 1,
//     bias 0).
//   * If normalize_by_lengths and the bag is non-empty, the bag result is
//     multiplied by 1 / bag_length.
//
// Returns false if the running position disagrees with offsets[i] -
// offsets[0], or if any index is outside [0, data_size). Otherwise it returns
// whether exactly index_size indices were consumed. On a false return, out
// may already be partially written. indices are read based on offsets alone;
// index_size is only compared at the very end.
//
// Entries are processed in groups of 16/8/4/2/1. Each group validates all
// its indices before touching any data.
//
// NOTE(review): this file is generated by sve_emblookup_codegen.py. The
// int64_t index widening below should also be applied in the generator.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      // Indices are widened to int64_t so that `2 * idx` (scale_bias lookup)
      // cannot overflow 32-bit int when data_size exceeds 2^30 rows.
      const int64_t idx0 = indices[pos + 0];
      const int64_t idx1 = indices[pos + 1];
      const int64_t idx2 = indices[pos + 2];
      const int64_t idx3 = indices[pos + 3];
      const int64_t idx4 = indices[pos + 4];
      const int64_t idx5 = indices[pos + 5];
      const int64_t idx6 = indices[pos + 6];
      const int64_t idx7 = indices[pos + 7];
      const int64_t idx8 = indices[pos + 8];
      const int64_t idx9 = indices[pos + 9];
      const int64_t idx10 = indices[pos + 10];
      const int64_t idx11 = indices[pos + 11];
      const int64_t idx12 = indices[pos + 12];
      const int64_t idx13 = indices[pos + 13];
      const int64_t idx14 = indices[pos + 14];
      const int64_t idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      if (scale_bias) {
        // Fold the per-row bias into one scalar added once per group, and the
        // per-row scale into the weight used by the multiply-accumulate.
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
        bio += wgt4 * scale_bias[2 * idx4 + 1];
        wgt4 = wgt4 * scale_bias[2 * idx4];
        bio += wgt5 * scale_bias[2 * idx5 + 1];
        wgt5 = wgt5 * scale_bias[2 * idx5];
        bio += wgt6 * scale_bias[2 * idx6 + 1];
        wgt6 = wgt6 * scale_bias[2 * idx6];
        bio += wgt7 * scale_bias[2 * idx7 + 1];
        wgt7 = wgt7 * scale_bias[2 * idx7];
        bio += wgt8 * scale_bias[2 * idx8 + 1];
        wgt8 = wgt8 * scale_bias[2 * idx8];
        bio += wgt9 * scale_bias[2 * idx9 + 1];
        wgt9 = wgt9 * scale_bias[2 * idx9];
        bio += wgt10 * scale_bias[2 * idx10 + 1];
        wgt10 = wgt10 * scale_bias[2 * idx10];
        bio += wgt11 * scale_bias[2 * idx11 + 1];
        wgt11 = wgt11 * scale_bias[2 * idx11];
        bio += wgt12 * scale_bias[2 * idx12 + 1];
        wgt12 = wgt12 * scale_bias[2 * idx12];
        bio += wgt13 * scale_bias[2 * idx13 + 1];
        wgt13 = wgt13 * scale_bias[2 * idx13];
        bio += wgt14 * scale_bias[2 * idx14 + 1];
        wgt14 = wgt14 * scale_bias[2 * idx14];
        bio += wgt15 * scale_bias[2 * idx15 + 1];
        wgt15 = wgt15 * scale_bias[2 * idx15];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      const uint8_t* const ip4 = &input[idx4 * block_size];
      const uint8_t* const ip5 = &input[idx5 * block_size];
      const uint8_t* const ip6 = &input[idx6 * block_size];
      const uint8_t* const ip7 = &input[idx7 * block_size];
      const uint8_t* const ip8 = &input[idx8 * block_size];
      const uint8_t* const ip9 = &input[idx9 * block_size];
      const uint8_t* const ip10 = &input[idx10 * block_size];
      const uint8_t* const ip11 = &input[idx11 * block_size];
      const uint8_t* const ip12 = &input[idx12 * block_size];
      const uint8_t* const ip13 = &input[idx13 * block_size];
      const uint8_t* const ip14 = &input[idx14 * block_size];
      const uint8_t* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      // Full vectors: widen uint8 -> uint32 on load, convert to float, then
      // multiply-accumulate into the output row.
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
        auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
        auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
        auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
        auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
        auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
        auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
        auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
        auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
        auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
        auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
        auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      // Remainder: predicate limits loads/stores to [k, block_size).
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
        auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
        auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
        auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
        auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
        auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
        auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
        auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
        auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
        auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
        auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
        auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      // int64_t keeps `2 * idx` below free of 32-bit overflow.
      const int64_t idx0 = indices[pos + 0];
      const int64_t idx1 = indices[pos + 1];
      const int64_t idx2 = indices[pos + 2];
      const int64_t idx3 = indices[pos + 3];
      const int64_t idx4 = indices[pos + 4];
      const int64_t idx5 = indices[pos + 5];
      const int64_t idx6 = indices[pos + 6];
      const int64_t idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
        bio += wgt4 * scale_bias[2 * idx4 + 1];
        wgt4 = wgt4 * scale_bias[2 * idx4];
        bio += wgt5 * scale_bias[2 * idx5 + 1];
        wgt5 = wgt5 * scale_bias[2 * idx5];
        bio += wgt6 * scale_bias[2 * idx6 + 1];
        wgt6 = wgt6 * scale_bias[2 * idx6];
        bio += wgt7 * scale_bias[2 * idx7 + 1];
        wgt7 = wgt7 * scale_bias[2 * idx7];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      const uint8_t* const ip4 = &input[idx4 * block_size];
      const uint8_t* const ip5 = &input[idx5 * block_size];
      const uint8_t* const ip6 = &input[idx6 * block_size];
      const uint8_t* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
        auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
        auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
        auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
        auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
        auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
        auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      // int64_t keeps `2 * idx` below free of 32-bit overflow.
      const int64_t idx0 = indices[pos + 0];
      const int64_t idx1 = indices[pos + 1];
      const int64_t idx2 = indices[pos + 2];
      const int64_t idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      // int64_t keeps `2 * idx` below free of 32-bit overflow.
      const int64_t idx0 = indices[pos + 0];
      const int64_t idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      // int64_t keeps `2 * idx` below free of 32-bit overflow.
      const int64_t idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    // Mean pooling: scale the bag by 1/length. Empty bags stay all-zero.
    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
// Non-template entry point: int32 indices/offsets, uint8 table, float output.
// Forwards every argument unchanged to the IS_WEIGHT_POSITIONAL == false
// instantiation of EmbeddingLookupIdx_int32_t_uint8_t_float__sve. In that
// instantiation, weights (if given) are indexed by the global position in
// `indices` rather than the position within each bag.
bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = false;
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias,
      normalize_by_lengths, out);
}
// Non-template entry point: int32 indices/offsets, uint8 table, float output.
// Forwards every argument unchanged to the IS_WEIGHT_POSITIONAL == true
// instantiation of EmbeddingLookupIdx_int32_t_uint8_t_float__sve. In that
// instantiation, weights (if given) are indexed by the position within each
// bag (j - offsets[i]).
bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  constexpr bool kWeightsArePositional = true;
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<kWeightsArePositional>(
      block_size, output_size, index_size, data_size,
      input, indices, offsets,
      weights, scale_bias,
      normalize_by_lengths, out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      const auto idx8 = indices[pos + 8];
      const auto idx9 = indices[pos + 9];
      const auto idx10 = indices[pos + 10];
      const auto idx11 = indices[pos + 11];
      const auto idx12 = indices[pos + 12];
      const auto idx13 = indices[pos + 13];
      const auto idx14 = indices[pos + 14];
      const auto idx15 = indices[pos + 15];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      if (idx8 < 0 || idx8 >= data_size) {
        return false;
      }
      if (idx9 < 0 || idx9 >= data_size) {
        return false;
      }
      if (idx10 < 0 || idx10 >= data_size) {
        return false;
      }
      if (idx11 < 0 || idx11 >= data_size) {
        return false;
      }
      if (idx12 < 0 || idx12 >= data_size) {
        return false;
      }
      if (idx13 < 0 || idx13 >= data_size) {
        return false;
      }
      if (idx14 < 0 || idx14 >= data_size) {
        return false;
      }
      if (idx15 < 0 || idx15 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float wgt8 = 1.f;
      float wgt9 = 1.f;
      float wgt10 = 1.f;
      float wgt11 = 1.f;
      float wgt12 = 1.f;
      float wgt13 = 1.f;
      float wgt14 = 1.f;
      float wgt15 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
        wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
        wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
        wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
        wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
        wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
        wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
        wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
        wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
        bio += wgt4 * scale_bias[2 * idx4 + 1];
        wgt4 = wgt4 * scale_bias[2 * idx4];
        bio += wgt5 * scale_bias[2 * idx5 + 1];
        wgt5 = wgt5 * scale_bias[2 * idx5];
        bio += wgt6 * scale_bias[2 * idx6 + 1];
        wgt6 = wgt6 * scale_bias[2 * idx6];
        bio += wgt7 * scale_bias[2 * idx7 + 1];
        wgt7 = wgt7 * scale_bias[2 * idx7];
        bio += wgt8 * scale_bias[2 * idx8 + 1];
        wgt8 = wgt8 * scale_bias[2 * idx8];
        bio += wgt9 * scale_bias[2 * idx9 + 1];
        wgt9 = wgt9 * scale_bias[2 * idx9];
        bio += wgt10 * scale_bias[2 * idx10 + 1];
        wgt10 = wgt10 * scale_bias[2 * idx10];
        bio += wgt11 * scale_bias[2 * idx11 + 1];
        wgt11 = wgt11 * scale_bias[2 * idx11];
        bio += wgt12 * scale_bias[2 * idx12 + 1];
        wgt12 = wgt12 * scale_bias[2 * idx12];
        bio += wgt13 * scale_bias[2 * idx13 + 1];
        wgt13 = wgt13 * scale_bias[2 * idx13];
        bio += wgt14 * scale_bias[2 * idx14 + 1];
        wgt14 = wgt14 * scale_bias[2 * idx14];
        bio += wgt15 * scale_bias[2 * idx15 + 1];
        wgt15 = wgt15 * scale_bias[2 * idx15];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      const uint8_t* const ip4 = &input[idx4 * block_size];
      const uint8_t* const ip5 = &input[idx5 * block_size];
      const uint8_t* const ip6 = &input[idx6 * block_size];
      const uint8_t* const ip7 = &input[idx7 * block_size];
      const uint8_t* const ip8 = &input[idx8 * block_size];
      const uint8_t* const ip9 = &input[idx9 * block_size];
      const uint8_t* const ip10 = &input[idx10 * block_size];
      const uint8_t* const ip11 = &input[idx11 * block_size];
      const uint8_t* const ip12 = &input[idx12 * block_size];
      const uint8_t* const ip13 = &input[idx13 * block_size];
      const uint8_t* const ip14 = &input[idx14 * block_size];
      const uint8_t* const ip15 = &input[idx15 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
        auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
        auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
        auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
        auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
        auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
        auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
        auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
        auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
        auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
        auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
        auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        output = svmla_x(svAll, output, input8, wgt8);
        output = svmla_x(svAll, output, input9, wgt9);
        output = svmla_x(svAll, output, input10, wgt10);
        output = svmla_x(svAll, output, input11, wgt11);
        output = svmla_x(svAll, output, input12, wgt12);
        output = svmla_x(svAll, output, input13, wgt13);
        output = svmla_x(svAll, output, input14, wgt14);
        output = svmla_x(svAll, output, input15, wgt15);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
        auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
        auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
        auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
        auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
        auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
        auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
        auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
        auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
        auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
        auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
        auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        output = svmla_x(pg, output, input8, wgt8);
        output = svmla_x(pg, output, input9, wgt9);
        output = svmla_x(pg, output, input10, wgt10);
        output = svmla_x(pg, output, input11, wgt11);
        output = svmla_x(pg, output, input12, wgt12);
        output = svmla_x(pg, output, input13, wgt13);
        output = svmla_x(pg, output, input14, wgt14);
        output = svmla_x(pg, output, input15, wgt15);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 16;
      pos += 16;
    }
    // unrolling 8 times
    while (j + 7 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      const auto idx4 = indices[pos + 4];
      const auto idx5 = indices[pos + 5];
      const auto idx6 = indices[pos + 6];
      const auto idx7 = indices[pos + 7];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      if (idx4 < 0 || idx4 >= data_size) {
        return false;
      }
      if (idx5 < 0 || idx5 >= data_size) {
        return false;
      }
      if (idx6 < 0 || idx6 >= data_size) {
        return false;
      }
      if (idx7 < 0 || idx7 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float wgt4 = 1.f;
      float wgt5 = 1.f;
      float wgt6 = 1.f;
      float wgt7 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
        wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
        wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
        wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
        wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
        bio += wgt4 * scale_bias[2 * idx4 + 1];
        wgt4 = wgt4 * scale_bias[2 * idx4];
        bio += wgt5 * scale_bias[2 * idx5 + 1];
        wgt5 = wgt5 * scale_bias[2 * idx5];
        bio += wgt6 * scale_bias[2 * idx6 + 1];
        wgt6 = wgt6 * scale_bias[2 * idx6];
        bio += wgt7 * scale_bias[2 * idx7 + 1];
        wgt7 = wgt7 * scale_bias[2 * idx7];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      const uint8_t* const ip4 = &input[idx4 * block_size];
      const uint8_t* const ip5 = &input[idx5 * block_size];
      const uint8_t* const ip6 = &input[idx6 * block_size];
      const uint8_t* const ip7 = &input[idx7 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
        auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
        auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
        auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        output = svmla_x(svAll, output, input4, wgt4);
        output = svmla_x(svAll, output, input5, wgt5);
        output = svmla_x(svAll, output, input6, wgt6);
        output = svmla_x(svAll, output, input7, wgt7);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
        auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
        auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
        auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        output = svmla_x(pg, output, input4, wgt4);
        output = svmla_x(pg, output, input5, wgt5);
        output = svmla_x(pg, output, input6, wgt6);
        output = svmla_x(pg, output, input7, wgt7);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 8;
      pos += 8;
    }
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos ++;
    }
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
// Non-positional-weight entry point: int64 indices/offsets, uint8 input rows,
// float output. Every argument is forwarded unchanged to the templated SVE
// kernel instantiated with IS_WEIGHT_POSITIONAL = false. In that mode a
// non-null `weights` array is read at the same flat position as the matching
// entry of `indices` (i.e. one weight per looked-up index).
// When `scale_bias` is non-null the kernel reads it as (scale, bias) pairs per
// input row: scale at [2 * idx], bias at [2 * idx + 1].
// The kernel's result is returned as-is. It is false, e.g., when an index lies
// outside [0, data_size) or when the indices consumed do not total index_size.
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Weights follow the flattened index stream, not the position inside a bag.
  constexpr bool kWeightIsPositional = false;
  return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<kWeightIsPositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}
// Positional-weight entry point: int64 indices/offsets, uint8 input rows,
// float output. Every argument is forwarded unchanged to the templated SVE
// kernel instantiated with IS_WEIGHT_POSITIONAL = true. In that mode a
// non-null `weights` array is read by the index's position relative to the
// start of its output segment (offset within the bag), so the same weight
// slots are reused by every bag.
// When `scale_bias` is non-null the kernel reads it as (scale, bias) pairs per
// input row: scale at [2 * idx], bias at [2 * idx + 1].
// The kernel's result is returned as-is. It is false, e.g., when an index lies
// outside [0, data_size) or when the indices consumed do not total index_size.
bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Weights are addressed by in-bag position rather than by flat index slot.
  constexpr bool kWeightIsPositional = true;
  return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<kWeightIsPositional>(
      block_size, output_size, index_size, data_size, input, indices, offsets,
      weights, scale_bias, normalize_by_lengths, out);
}

} // namespace caffe2
