/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmEmbedding.h"

#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cmath>
#include <iostream>
#include <mutex>
#include <string>
#include <tuple>
#include "./CodeCache.h"
#include "./MaskAvx2.h"
#include "./RefImplementations.h"
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"

namespace fbgemm {

namespace {
namespace x86 = asmjit::x86;

template <typename indxType = std::int64_t>
class ReturnFunctionSignature {
 public:
  using jit_sparse_adagrad_kernel = int (*)(
      int num_rows, // number of rows reading
      std::uint64_t param_size, // total number of parameters
      float* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      float epsilon,
      float lr,
      const int* mask_avx2,
      float weight_decay,
      const double* counter,
      std::int64_t counter_halflife);
};

template <
    typename indxType = std::int64_t,
    inst_set_t instSet = inst_set_t::avx2>
class GenSparseAdagrad {
 public:
  GenSparseAdagrad() {}
  void genSparseAdagrad(
      x86::Emitter* a,
      int unroll_factor,
      int num_vec_regs_per_block,
      int remainder,
      int prefetch,
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg,
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  void genRowwiseSparseAdagrad(
      x86::Emitter* a,
      int block_size,
      int unroll_factor,
      int num_vec_regs_per_block,
      int remainder,
      int prefetch,
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg,
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
  getOrCreate(
      int block_size,
      int prefetch,
      bool rowwise,
      bool has_weight_decay);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; // JIT Runtime for asmjit
    return rt;
  }

  static std::mutex rtMutex_; /// Controll access to runtime;

  // The hash depends on embedding dimension (block size), prefetch distance,
  // rowwise, and has_weight_decay
  static CodeCache<
      std::tuple<int, int, bool, bool>,
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.

  // These are register we share accross SparseAdagrad and RowwiseSparseAdagrad
  x86::Gp w;
  x86::Gp g;
  x86::Gp h;
  x86::Gp indices;
  x86::Gp base_offset;
  x86::Gp temp1_; // loop counter
  x86::Gp temp2_; // prefetch offset
  x86::Gp temp3_; // prefetch offset of moment in rowwise adagrad

  x86::KReg reduce_mask_avx512_;
}; // GenEmbeddingLookup

template <typename indxType, inst_set_t instSet>
std::mutex GenSparseAdagrad<indxType, instSet>::rtMutex_;

template <typename indxType, inst_set_t instSet>
CodeCache<
    std::tuple<int, int, bool, bool>,
    typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
    GenSparseAdagrad<indxType, instSet>::codeCache_;

template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genSparseAdagrad(
    x86::Emitter* a,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  // NOTE: temp_vreg is defined only when remainder is true and instSet == avx2
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);
      vec_reg_t g_vreg = vec_reg_t(v + cur_unroll_factor);

      if (prefetch && ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        // Intel SDE (wrongly) thinks prefetchwt1 is not available in BDW
        a->prefetchw(
            x86::dword_ptr(h, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));

        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto h_ptr = x86::dword_ptr(
          h, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(g_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            // TODO(@taiqing) use a vreg for weights to avoid duplicate indexing
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(g_vreg, temp_vreg, weight_decay_vreg);
          }
          a->vmulps(out_vreg, g_vreg, g_vreg);
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, h_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(h_ptr, mask_vreg, out_vreg.ymm());

          a->vsqrtps(out_vreg, out_vreg);
          a->vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->vmulps(g_vreg, lr_vreg, g_vreg);
          a->vdivps(out_vreg, g_vreg, out_vreg);

          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else if (instSet == inst_set_t::avx512) {
          a->k(x86::k(1)).vmovups(g_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(x86::k(1)).vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
          }
          a->k(x86::k(1)).vmulps(out_vreg, g_vreg, g_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, h_ptr);

          a->k(x86::k(1)).vmovups(h_ptr, out_vreg);

          a->k(x86::k(1)).vsqrtps(out_vreg, out_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->k(x86::k(1)).vmulps(g_vreg, lr_vreg, g_vreg);
          a->k(x86::k(1)).vdivps(out_vreg, g_vreg, out_vreg);

          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);

          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        a->vmovups(g_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
        }
        a->vmulps(out_vreg, g_vreg, g_vreg);
        a->vaddps(out_vreg, out_vreg, h_ptr);

        a->vmovups(h_ptr, out_vreg);

        a->vsqrtps(out_vreg, out_vreg);
        a->vaddps(out_vreg, out_vreg, epsilon_vreg);

        a->vmulps(g_vreg, lr_vreg, g_vreg);
        a->vdivps(out_vreg, g_vreg, out_vreg);

        a->vaddps(out_vreg, out_vreg, w_ptr);

        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}

template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genRowwiseSparseAdagrad(
    x86::Emitter* a,
    int block_size,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;

  // Reduce the unroll factor by 1 for partial sum
  --unroll_factor;
  vec_reg_t partial_sum_vreg = vec_reg_t(unroll_factor);

  if (prefetch) {
    a->prefetchw(x86::dword_ptr(h, temp3_));
  }

  bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
  auto indices_ptr = areIndices64b
      ? x86::qword_ptr(
            indices, temp1_, 3) // use of 3 is to muliply by 8 (int64_t)
      : x86::dword_ptr(
            indices, temp1_, 2); // use of 2 is to muliply by 4 (int32_t)
  if (has_weight_decay) {
    // set base_offset for fetching w in the calculation of gradient square sum
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(block_size * sizeof(float)));
  }

  // Even with avx512, we only need to use avx2 registers when computing
  // partial_sum because some instructions we're using like vhaddps
  // are only in avx2.
  constexpr int vlen_avx2 = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
  int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2;

  // Use YMM/XMMs with smaller ids for AVX2 specific instructions like vhaddps
  x86::Ymm partial_sum_vreg_avx2(0);
  x86::Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id());

  a->vxorps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  // TODO: need to do a tree-reduction to fully take advantage of unrolling
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
       vec_idx += unroll_factor - 1) {
    int cur_unroll_factor =
        std::min(unroll_factor - 1, num_vec_regs_per_block_avx2 - vec_idx);
    for (int v = 0; v < cur_unroll_factor; ++v) {
      x86::Ymm out_vreg = x86::Ymm(v + 1);
      if (has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen_avx2 * sizeof(float))) == 0)) {
        a->prefetchw(x86::dword_ptr(
            w, temp2_, 0, (vec_idx + v) * vlen_avx2 * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen_avx2 * sizeof(float));
      if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
          vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(out_vreg, temp_vreg, weight_decay_vreg);
          }
        } else {
          a->k(reduce_mask_avx512_).z().vmovups(out_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(reduce_mask_avx512_)
                .z()
                .vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          }
        }
      } else {
        a->vmovups(out_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
        }
      }
      a->vmulps(out_vreg, out_vreg, out_vreg);
      a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg);
    }
  }
  // Reduce sum to 1 value
  // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
  // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  x86::Xmm partial_sum_xmm1(1);

  //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
  a->movss(partial_sum_xmm1, partial_sum_xmm0);
  //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
  a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1);

  // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
  //    _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
  a->addss(partial_sum_xmm0, partial_sum_xmm1);

  // This fragment moves block size (N) to stack and bcasts it to xmm reg
  a->lea(
      x86::rsp,
      x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
  a->mov(x86::dword_ptr(x86::rsp), block_size);
  a->vbroadcastss(
      partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1
  a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1);
  a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));

  if (has_weight_decay) {
    // set base_offset for fetching h
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(sizeof(float)));
  }

  // final_sum /= N
  a->divss(partial_sum_xmm0, partial_sum_xmm1);
  // load h
  a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset));
  // *h + final_sum
  a->addss(partial_sum_xmm0, partial_sum_xmm1);
  // store h
  a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0);
  // sqrt(hi)
  a->sqrtss(partial_sum_xmm0, partial_sum_xmm0);
  // bcast partial to all of ymm/zmm reg
  a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0);
  // lr / sqrt(hi) + epsilon
  a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg);
  a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg);
  // partial_sum_vreg now has float_step

  // set base_offset for fetching w in updating weights
  a->imul(
      areIndices64b ? base_offset : base_offset.r32(),
      indices_ptr,
      static_cast<asmjit::Imm>(block_size * sizeof(float)));

  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);

      if (!has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
            // TODO(@taiqing): have vreg for weights
            a->vfmadd231ps(temp_vreg, weight_decay_vreg, out_vreg);
          }
          a->vmulps(temp_vreg, partial_sum_vreg, temp_vreg);

          a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, temp_vreg, out_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else {
          if (has_weight_decay) {
            a->k(x86::k(1)).vmovups(out_vreg, g_ptr);
            a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg);
          } else {
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr);
          }
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        if (has_weight_decay) {
          a->vmovups(out_vreg, g_ptr);
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          a->vmulps(out_vreg, partial_sum_vreg, out_vreg);
        } else {
          a->vmulps(out_vreg, partial_sum_vreg, g_ptr);
        }
        a->vaddps(out_vreg, out_vreg, w_ptr);
        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}

template <typename indxType, inst_set_t instSet>
typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
GenSparseAdagrad<indxType, instSet>::getOrCreate(
    int block_size,
    int prefetch,
    bool rowwise,
    bool has_weight_decay) {
  std::tuple<int, int, bool, bool> kernelSig =
      std::make_tuple(block_size, prefetch, rowwise, has_weight_decay);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() ->
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel {
        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
        bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
#if defined(FBGEMM_LOG_CODE)
        std::string filename = "SparseAdagrad";
        filename += "_emd_dim_" + std::to_string(block_size);
        if (rowwise) {
          filename += "_rowwise";
        }
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        if (has_weight_decay) {
          filename += "weight_decay";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif

        x86::Gpd num_rows = a->zdi().r32();
        x86::Gp param_size = a->zsi();
        w = a->zdx();
        g = a->zcx();
        h = a->gpz(8);
        indices = a->gpz(9);
        x86::Xmm epsilon(0);
        x86::Xmm lr(1);
        x86::Gp mask_avx2 = a->gpz(10);
        x86::Xmm weight_decay(2);
        x86::Gp counter = a->gpz(11);
        x86::Gp counter_halflife = a->gpz(12);

        // reuse mask_avx2 because mask_avx2 is used only at the beginning
        base_offset = a->gpz(10);
        temp1_ = a->gpz(13);
        temp2_ = a->gpz(14);
        temp3_ = a->gpz(15);

        asmjit::FuncDetail func;
        func.init(
            asmjit::FuncSignatureT<
                int, // return type
                int, // num rows
                std::uint64_t, // param_size
                float*, // w
                const float*, // g
                float*, // h
                const indxType*, // indices
                float, // epsilon
                float, // lr
                const int*, // mask_avx2
                float, // weight_decay
                const double*, // counter then counter_halflife
                std::int64_t>(asmjit::CallConvId::kHost),
            a->environment());

        asmjit::FuncFrame frame;
        frame.init(func);

        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

        asmjit::FuncArgsAssignment args(&func);
        args.assignAll(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            mask_avx2,
            weight_decay,
            counter,
            counter_halflife);

        args.updateFuncFrame(frame);
        frame.finalize();
        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
        int unroll_factor = NUM_VEC_REG;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        vec_reg_t epsilon_vreg;
        vec_reg_t lr_vreg;
        vec_reg_t weight_decay_vreg;
        vec_reg_t adjusted_weight_decay_vreg;
        x86::Ymm mask_vreg; // mask for avx2
        vec_reg_t
            temp_vreg; // temp vreg for avx2 to handle remainder computation

        --unroll_factor;
        epsilon_vreg = vec_reg_t(unroll_factor);
        --unroll_factor;
        lr_vreg = vec_reg_t(unroll_factor);
        if (has_weight_decay) {
          --unroll_factor;
          weight_decay_vreg = vec_reg_t(unroll_factor);
          --unroll_factor;
          adjusted_weight_decay_vreg = vec_reg_t(unroll_factor);
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            temp_vreg = vec_reg_t(unroll_factor);
          }

          // Creating masks for non multiples of vlen iterations
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            mask_vreg = x86::Ymm(unroll_factor);
            a->vmovups(mask_vreg, x86::dword_ptr(mask_avx2));
          } else {
            a->mov(temp1_, (1 << remainder) - 1);
            a->kmovw(x86::k(1), temp1_);
          }
        }
        // Need an extra mask for computing sum of gradients
        int remainder_avx2 =
            block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        if (remainder_avx2 && instSet == inst_set_t::avx512 && rowwise) {
          reduce_mask_avx512_ = x86::k(2);
          a->mov(temp1_, (1 << remainder_avx2) - 1);
          a->kmovw(reduce_mask_avx512_, temp1_);
        }

        if (!rowwise) {
          unroll_factor = unroll_factor / 2; // accont for g_vreg
        }

        asmjit::Label exit = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        a->vpbroadcastd(epsilon_vreg, epsilon);
        a->vpbroadcastd(lr_vreg, lr);
        if (has_weight_decay) {
          a->vpbroadcastd(weight_decay_vreg, weight_decay);
        }

        a->xor_(temp1_, temp1_);

        a->bind(LoopRangeIndexBegin);
        a->cmp(temp1_.r32(), num_rows); // temp1_ is the loop trip counter
        a->jge(LoopRangeIndexEnd);

        auto indices_ptr = areIndices64b
            ? x86::qword_ptr(
                  indices, temp1_, 3) // use of 3 is to muliply by 8 (int64_t)
            : x86::dword_ptr(
                  indices, temp1_, 2); // use of 2 is to muliply by 4 (int32_t)
        a->imul(
            areIndices64b ? base_offset : base_offset.r32(),
            indices_ptr,
            static_cast<asmjit::Imm>(
                (rowwise ? 1 : block_size) * sizeof(float)));

        // Perform this check
        // if (block_size + offsetIdx > param_size) {
        //   return i;
        // }
        if (areIndices64b) {
          a->mov(temp2_, indices_ptr);
        } else {
          a->mov(temp2_.r32(), indices_ptr);
        }

        if (has_weight_decay) {
          // Check counter != nullptr && counter[idx] > 0
          a->vmovaps(adjusted_weight_decay_vreg, weight_decay_vreg);

          asmjit::Label skip_adjust_freq = a->newLabel();

          a->cmp(counter, 0);
          a->je(skip_adjust_freq);

          // temp3_ : counter[idx]
          a->mov(temp3_, x86::qword_ptr(counter, temp2_, 3));
          a->cmp(temp3_, 0);
          a->jle(skip_adjust_freq);

          // OK to use Xmm registers with small ids that are reserved for temp
          // values in the inner most loop.
          vec_reg_t counter_halflife_vreg(0);
          x86::Xmm counter_vreg(1);
          a->cvtsi2sd(counter_halflife_vreg.xmm(), counter_halflife);
          a->movq(counter_vreg, temp3_);
          a->divpd(counter_halflife_vreg.xmm(), counter_vreg);
          a->vcvtpd2ps(
              counter_halflife_vreg.xmm(), counter_halflife_vreg.ymm());
          a->vbroadcastss(counter_halflife_vreg, counter_halflife_vreg.xmm());
          a->vmulps(
              adjusted_weight_decay_vreg,
              adjusted_weight_decay_vreg,
              counter_halflife_vreg);

          a->bind(skip_adjust_freq);
        }

        a->inc(temp2_);
        a->imul(
            temp2_,
            static_cast<asmjit::Imm>(block_size)); //(offsetIdx+1)*blocksize
        a->cmp(temp2_, param_size);
        a->jg(exit);

        if (prefetch) {
          asmjit::Label pref_dist_reset_start = a->newLabel();
          asmjit::Label pref_dist_reset_end = a->newLabel();

          a->mov(temp2_, temp1_);
          a->add(temp2_, prefetch);
          a->cmp(temp2_.r32(), num_rows);
          a->jge(pref_dist_reset_start);

          auto pref_indices_ptr = areIndices64b
              ? x86::qword_ptr(indices, temp2_, 3)
              : x86::dword_ptr(indices, temp2_, 2);
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                pref_indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              pref_indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));

          a->jmp(pref_dist_reset_end);

          a->bind(pref_dist_reset_start);
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }

          a->bind(pref_dist_reset_end);
        } // prefetch

        if (rowwise) {
          genRowwiseSparseAdagrad(
              a,
              block_size,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        } else {
          genSparseAdagrad(
              a,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        }

        a->add(g, static_cast<asmjit::Imm>(block_size * sizeof(float)));
        a->inc(temp1_);
        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        a->bind(exit);
        a->mov(x86::eax, temp1_.r32());
        a->emitEpilog(frame);

        typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
            fn;
        asmjit::Error err;
        {
          std::unique_lock<std::mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          std::cout << "Error: in fn add" << std::endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
} // getOrCreate

// Specialization for block size 1 internally called by GenerateSparseAdaGrad
template <typename IndexType>
int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const IndexType* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife) {
  if (weight_decay != 0.0f) {
    for (int i = 0; i < num_rows; ++i) {
      IndexType idx = indices[i];
      if (idx >= static_cast<int64_t>(param_size)) {
        return i;
      }

      float freq = (counter && counter[idx] > 0)
          ? counter_halflife / counter[idx]
          : 1.0f;
      float gi = std::fma(freq * weight_decay, w[idx], g[i]);
      float hi = h[idx] = h[idx] + gi * gi;
      if (rowwise) {
        w[idx] += lr / (std::sqrt(hi) + epsilon) * gi;
      } else {
        w[idx] += lr * gi / (std::sqrt(hi) + epsilon);
      }
    }
  } else {
    for (int i = 0; i < num_rows; ++i) {
      IndexType idx = indices[i];
      if (idx >= static_cast<int64_t>(param_size)) {
        return i;
      }
      float gi = g[i];
      float hi = h[idx] = h[idx] + gi * gi;
      if (rowwise) {
        w[idx] += lr / (std::sqrt(hi) + epsilon) * gi;
      } else {
        w[idx] += lr * gi / (std::sqrt(hi) + epsilon);
      }
    }
  }
  return num_rows;
}

template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);

template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);

} // namespace

template <typename IndexType>
typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
    int block_size,
    bool rowwise,
    int prefetch,
    bool use_weight_decay) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }

  if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
    if (block_size == 1) {
      return [=](int num_rows, // number of rows reading
                 std::uint64_t param_size, // total number of parameters
                 float* w, // input/output parameters
                 const float* g, // input gradients
                 float* h, // input/output momentums
                 const IndexType* indices, // indices of each row
                 float epsilon,
                 float lr,
                 float weight_decay,
                 const double* counter,
                 std::int64_t counter_halflife) {
        return SparseAdaGradBlockSize1_(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            rowwise,
            weight_decay,
            counter,
            counter_halflife);
      };
    }
    static GenSparseAdagrad<IndexType, inst_set_t::avx2> kernel_generator;
    constexpr int VLEN = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
    const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
                               [(VLEN - (block_size % VLEN)) % VLEN];
    const auto original_func = kernel_generator.getOrCreate(
        block_size, prefetch, rowwise, use_weight_decay);
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      return original_func(
          num_rows, // number of rows reading
          param_size, // total number of parameters
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          epsilon,
          lr,
          mask_avx2,
          weight_decay,
          counter,
          counter_halflife);
    };
  } else {
#ifdef VLOG
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      if (rowwise) {
        return rowwise_sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per rows
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      } else {
        return sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per rows
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      }
    };
  }
}

template FBGEMM_API typename SparseAdaGradSignature<std::int64_t>::Type
GenerateSparseAdaGrad<std::int64_t>(
    int block_size, // number of parameters per rows
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

template FBGEMM_API typename SparseAdaGradSignature<std::int32_t>::Type
GenerateSparseAdaGrad<std::int32_t>(
    int block_size, // number of parameters per rows
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

} // namespace fbgemm
