#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/TensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
#endif

#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()
// Forward declaration only; no definition is visible in this file.
// Used by embedding_bag_backward_cuda_sum_avg as the fallback that fills
// `count` from `sorted_indices` when CUB_SUPPORTS_SCAN_BY_KEY() is false.
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

namespace {

// Returns {indices, offsets} converted to the dtype obtained by promoting
// their two scalar types. A tensor that already has the promoted dtype is
// returned as-is (no conversion); the other one is converted with toType().
std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
    const Tensor& indices,
    const Tensor& offsets) {
  const auto indicesType = indices.scalar_type();
  const auto offsetsType = offsets.scalar_type();
  const auto commonType = promoteTypes(offsetsType, indicesType);

  Tensor promotedIndices = indices;
  if (indicesType != commonType) {
    promotedIndices = indices.toType(commonType);
  }

  Tensor promotedOffsets = offsets;
  if (offsetsType != commonType) {
    promotedOffsets = offsets.toType(commonType);
  }

  return {std::move(promotedIndices), std::move(promotedOffsets)};
}

// Forward kernel for MAX mode.
//
// This kernel assumes that all input tensors except `weight` are contiguous
// (`weight` is addressed through weight_stride0 / weight_stride1).
//
// Work decomposition: the feature dimension of each bag is cut into chunks of
// blockDim.x features. Each (blockIdx.x, threadIdx.y) pair walks over chunks
// with a stride of gridDim.x * blockDim.y, and threadIdx.x selects the feature
// inside the chunk, so every (bag, feature) pair is handled by one thread.
//
// Per (bag, feature) the kernel writes:
//   output[bag, feature]      - maximum weight value over non-padding entries
//                               (0 if the bag has none),
//   max_indices[bag, feature] - embedding row that produced the maximum
//                               (-1 if the bag has none),
//   bag_size[bag]             - number of non-padding entries in the bag,
//   offset2bag[emb]           - owning bag of each entry (written only by the
//                               thread handling feature 0).
//
// Every index, including one equal to `padding_idx`, is used to read `weight`
// before the padding decision is applied, so all indices must lie in
// [0, numRows). This is enforced with CUDA_KERNEL_ASSERT, matching the
// sum/mean kernel.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_max(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    index_t *bag_size, index_t *max_indices,
    index_t padding_idx, int64_t numRows) {

  // the strategy here is that each bag x feature is handled by a single thread

  int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < featureSize) {
      int64_t bag = chunk / chunksPerBag;
      const scalar_t *weightFeat = weight + featureDim * weight_stride1;
      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      CUDA_KERNEL_ASSERT(end >= begin);
      scalar_t weightFeatMax = 0;
      int64_t bag_size_ = 0;
      int64_t maxWord = -1;
      for (int64_t emb = begin; emb < end; emb++) {
        // Load the index once; it is used for the padding test, the bounds
        // check, the weight lookup and the recorded arg-max.
        const index_t input_idx = input[emb];
        const bool pad = (input_idx == padding_idx);
        // Reject negative indices as well as indices past the last row:
        // either would read outside the weight table.
        CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
        const int64_t weightRow = input_idx * weight_stride0;
        scalar_t weightValue = weightFeat[weightRow];
        // bag_size_ == 0 means no non-padding entry has been seen yet, so the
        // first real value always replaces the initial 0.
        if (bag_size_ == 0 || weightValue > weightFeatMax) {
          weightFeatMax = pad ? weightFeatMax : weightValue;
          maxWord = pad ? maxWord : input_idx;
        }
        bag_size_ += pad ? 0 : 1;

        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      bag_size[bag] = bag_size_;
      max_indices[bag * featureSize + featureDim] = maxWord;
      output[bag * featureSize + featureDim] = weightFeatMax;
    }
  }
}

// Forward kernel for the SUM and MEAN modes.
//
// `weight` and `per_sample_weights` are addressed through explicit strides;
// every other tensor argument is indexed as a contiguous array.
//
// Each (bag, feature) pair is reduced by exactly one thread: the feature axis
// is split into chunks of blockDim.x features, (blockIdx.x, threadIdx.y)
// selects the chunk and threadIdx.x the feature inside it.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_sum_mean(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    int mode, index_t *bag_size,
    const scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    index_t padding_idx, int64_t numRows) {

  using accscalar_t = acc_type<scalar_t, true>;

  const int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  const int64_t totalChunks = numBags * chunksPerBag;
  const int64_t firstChunk = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t chunkStep = gridDim.x * blockDim.y;
  const bool averageBag = (mode == static_cast<int64_t>(EmbeddingBagMode::MEAN));

  for (int64_t chunk = firstChunk; chunk < totalChunks; chunk += chunkStep) {
    const int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim >= featureSize) {
      // Tail of the last chunk of a bag: nothing to do for this thread.
      continue;
    }

    const int64_t bag = chunk / chunksPerBag;
    const scalar_t *featureColumn = weight + featureDim * weight_stride1;

    // The first bag always starts at 0 (offsets[0] is not consulted) and the
    // last bag always ends at numIndices.
    int64_t begin = 0;
    if (bag != 0) {
      begin = offsets[bag];
    }
    int64_t end = numIndices;
    if (bag < numBags - 1) {
      end = offsets[bag + 1];
    }
    CUDA_KERNEL_ASSERT(end >= begin);

    accscalar_t weightFeatSum = 0;
    int64_t nonPadCount = 0;
    for (int64_t emb = begin; emb < end; emb++) {
      const index_t input_idx = input[emb];
      const bool pad = (input_idx == padding_idx);
      CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);

      // The row is read unconditionally; padding entries then contribute 0.
      const int64_t weightRow = input_idx * weight_stride0;
      scalar_t weightValue = featureColumn[weightRow];
      if (pad) {
        weightValue = static_cast<scalar_t>(0);
      }

      if (per_sample_weights) {
        accscalar_t scaleWeightBy = static_cast<accscalar_t>(
            per_sample_weights[emb * per_sample_weights_stride]);
        weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue);
      } else {
        weightFeatSum += static_cast<accscalar_t>(weightValue);
      }

      if (!pad) {
        nonPadCount += 1;
      }

      // Only the thread owning feature 0 records the entry -> bag mapping.
      if (featureDim == 0) {
        offset2bag[emb] = bag;
      }
    }

    // MEAN divides by the number of non-padding entries; empty bags keep 0.
    if (averageBag && nonPadCount != 0) {
      weightFeatSum = weightFeatSum / static_cast<accscalar_t>(nonPadCount);
    }

    bag_size[bag] = nonPadCount;
    output[bag * featureSize + featureDim] = static_cast<scalar_t>(weightFeatSum);
  }
}

// Dense weight gradient for the SUM and MEAN modes.
//
// Steps:
//   1. Sort the (contiguous) indices with a radix sort, carrying along each
//      entry's original position (`orig_indices`).
//   2. If `scale_grad_by_freq` is set, compute for every sorted entry how many
//      times its index occurs (`count`); otherwise `count` stays undefined.
//   3. Hand everything to embedding_backward_cuda_kernel, which produces the
//      gradient tensor that is returned.
//
// When there are no indices at all, a zero tensor of shape
// {num_weights, grad.size(1)} is returned without launching any kernel.
Tensor embedding_bag_backward_cuda_sum_avg(
                                   const Tensor &grad,
                                   const Tensor &indices_,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode,
                                   const Tensor& per_sample_weights,
                                   int64_t padding_idx) {
  auto indices = indices_.contiguous();

  ptrdiff_t num_indices = indices.numel();

  if (num_indices == 0) {
    // all empty bags
    return at::zeros({num_weights, grad.size(1)}, grad.options());
  }

  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
    // Keys are the indices, values are 0..num_indices-1, so after the sort
    // orig_indices[i] is the original position of sorted_indices[i].
    auto range = at::arange(num_indices, indices.options());
    // int64_t nbits = cuda::cub::get_num_bits(num_weights);
    cuda::cub::radix_sort_pairs(
      indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(),
      range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(),
      num_indices, false/*, 0, nbits*/);
  });

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = sorted_indices.const_data_ptr<index_t>();
      auto count_data = count.mutable_data_ptr<index_t>();
      cuda::cub::inclusive_sum_by_key(
        sorted_data,
        NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<index_t>(1),
        count_data,
        num_indices
      );

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      cuda::cub::inclusive_scan_by_key(
        thrust::make_reverse_iterator(sorted_data + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(),
        num_indices
      );
    });
#else
    // Scan-by-key is unavailable in this CUB: use the separately defined
    // fallback to fill `count`.
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }
  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
      count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
      bag_size, per_sample_weights);
}

// Backward kernel for MAX mode: scatters each output gradient element to the
// embedding row that produced the maximum in the forward pass.
//
//   max_indices - per (bag, feature) winning row, indexed as
//                 bag * stride + featureDim; negative means "empty bag".
//   gradOutput  - incoming gradient, indexed as bag * stride + featureDim.
//   gradWeight  - accumulated with atomic adds at word_idx * stride + featureDim.
//   stride      - number of features per row of all three buffers.
//   numel       - total element count of gradWeight, forwarded to
//                 fastAtomicAdd.
//
// Work decomposition is the same as in the forward kernels: (blockIdx.x,
// threadIdx.y) selects a chunk of blockDim.x features, threadIdx.x the feature.
//
// The flat gradWeight offset and `numel` are kept in 64 bits regardless of
// index_t, so that int32 indices into a table with more than 2^31 elements do
// not wrap around.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
    const index_t *max_indices, const scalar_t *gradOutput,
    scalar_t *gradWeight, int64_t stride, int64_t numBags,
    index_t padding_idx, const int64_t numel) {

  int64_t chunksPerBag = ceil_div(stride, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;

      index_t word_idx = max_indices[bag * stride + featureDim];
      if (word_idx >= 0 && word_idx != padding_idx) {
        // If bag is empty, we have max_indices[idx] set to -1 in forward.
        // Widen before multiplying: the product can exceed the index_t range.
        const int64_t gradWeightOffset =
            static_cast<int64_t>(word_idx) * stride + featureDim;
        fastAtomicAdd(
            gradWeight, gradWeightOffset,
            numel, gradOutput[bag * stride + featureDim], true);
      }
    }
  }
}

// Dense weight gradient for MAX mode.
//
// Allocates a zero-filled {num_weights, grad.size(1)} tensor and launches
// EmbeddingBag_accGradParametersKernel_max to add every gradient element into
// the row recorded in `max_indices`. The result is flagged as
// nondeterministic because the accumulation uses atomic adds.
//
// The kernel indexes `grad` and `max_indices` as row-major buffers with
// grad_weight.stride(0) elements per row.
// NOTE(review): this relies on both being contiguous; `grad` is made
// contiguous by the visible caller, `max_indices` is not checked here.
//
// The floating-point dispatch covers Half and BFloat16, matching the dtypes
// accepted by the forward pass in _embedding_bag_cuda.
Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   int64_t padding_idx) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("embedding_bag_backward_cuda_max");

  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

  int64_t stride = grad_weight.stride(0);

  int64_t numBags = grad.size(0);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  // Fixed grid size; the kernel loops over chunks with a stride, so any
  // number of bags/features is covered.
  int grid = 1024;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
        AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
          EmbeddingBag_accGradParametersKernel_max<
              scalar_t, index_t><<<grid, block, 0, stream>>>(
              max_indices.const_data_ptr<index_t>(), grad.const_data_ptr<scalar_t>(),
              grad_weight.mutable_data_ptr<scalar_t>(), stride, numBags,
              padding_idx, grad_weight.numel());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
      });

  return grad_weight;
}
}

// Forward-only entry point: unwraps the optional per-sample weights and
// delegates, with all arguments unchanged, to _embedding_bag_cuda.
// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_forward_only_cuda(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  const c10::MaybeOwned<Tensor> borrowed_weights =
      at::borrow_from_optional_tensor(per_sample_weights_opt);

  return _embedding_bag_cuda(
      weight, indices, offsets,
      scale_grad_by_freq, mode, sparse,
      *borrowed_weights,
      include_last_offset, padding_idx);
}

// CUDA forward pass of embedding_bag.
//
// Validates tensor ranks, promotes `indices_`/`offsets_` to a common integer
// dtype, allocates the outputs and launches either the MAX kernel or the
// SUM/MEAN kernel depending on `mode`.
//
// Returns the tuple (output, offset2bag, bag_size, max_indices):
//   output      - {numBags, featureSize}, same options as `weight`
//   offset2bag  - {indices.size(0)}, filled by the kernel with the bag of each entry
//   bag_size    - shaped like `offsets`, filled by the kernel per bag
//   max_indices - {numBags, featureSize} in MAX mode, otherwise an empty tensor
//
// `scale_grad_by_freq` and `sparse` are accepted but not read in this function.
//
// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
                   const Tensor &offsets_, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
      "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
      indices_.dim());
  if (indices_.dim() == 1) {
    TORCH_CHECK(offsets_.dim() == 1,
        "offsets has to be a 1D Tensor, but got Tensor of dimension ",
        offsets_.dim());
  }
  TORCH_CHECK(weight.dim() == 2,
      "weight has to be a 2D Tensor, but got Tensor of dimension ",
      weight.dim());
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  // Bring indices and offsets to one dtype so a single index_t can be used
  // for both in the kernels, then verify dtype (Long/Int) and device.
  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

  int64_t numIndices = indices.size(0);
  int64_t numBags = offsets.size(0);
  if (include_last_offset) {
    // Check https://github.com/pytorch/pytorch/issues/29019
    // We plan to add one more element in offsets, which is equal to the size of
    // indices. Currently for cuda devices, we still use the legacy
    // implementation even this flag is enabled.
    TORCH_CHECK(
        numBags >= 1, "include_last_offset: numBags should be at least 1");
    numBags -= 1;
  }
  int64_t featureSize = weight.size(1);

  // NOTE(review): bag_size is sized from offsets, not from numBags. With
  // include_last_offset it has one more element than there are bags, and the
  // kernels only write entries [0, numBags) — the last element appears to be
  // left uninitialized; confirm callers do not read it.
  auto bag_size = at::empty(offsets.sizes(), indices.options());
  auto offset2bag =
      at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto output = at::empty({numBags, featureSize}, weight.options());

  Tensor max_indices;

  if (mode == EmbeddingBagMode::MAX) {
    max_indices = at::empty({numBags, featureSize}, indices.options());
  } else {
    // No need to allocate if we aren't doing a backwards pass
    max_indices = at::empty({0}, indices.options());
  }

  // block.x indexes features inside a chunk, block.y selects the chunk; the
  // kernels loop over chunks with a stride of gridDim.x * blockDim.y, so the
  // fixed grid of 1024 blocks covers any problem size.
#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  int grid = 1024;
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
      if (mode == EmbeddingBagMode::MAX) {
        EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), bag_size.mutable_data_ptr<index_t>(),
            max_indices.mutable_data_ptr<index_t>(),
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        // SUM and MEAN share one kernel; `mode` selects the final division and
        // a null per_sample_weights pointer disables per-sample scaling.
        EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), mode, bag_size.mutable_data_ptr<index_t>(),
            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : NULL,
            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  });

  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
}

// Dense backward entry point: makes the incoming gradient contiguous, checks
// that it lives on the same GPU as `indices`, and dispatches on `mode` to the
// SUM/MEAN or the MAX implementation. Per-sample weights are only permitted
// in SUM mode; any other mode value raises an error.
Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size_,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
                                   int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  const c10::MaybeOwned<Tensor> weights_holder =
      at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *weights_holder;

  // indices and offset2bag are assumed to already have the correct dtypes and
  // to be contiguous, thanks to the checks in _embedding_bag_backward in
  // EmbeddingBag.cpp.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.

  Tensor grad = grad_.contiguous();
  auto indices_checked = TensorArg(indices, "indices", 1);
  auto grad_checked = TensorArg(grad, "grad", 1);
  checkSameGPU("embedding_bag_cuda", grad_checked, indices_checked);

  const auto bag_mode = static_cast<EmbeddingBagMode>(mode);
  switch (bag_mode) {
    case EmbeddingBagMode::MAX: {
      AT_ASSERT(!per_sample_weights.defined());
      return embedding_bag_backward_cuda_max(
          grad, max_indices, num_weights, padding_idx);
    }

    case EmbeddingBagMode::SUM:
    case EmbeddingBagMode::MEAN: {
      if (mode == EmbeddingBagMode::MEAN) {
        AT_ASSERT(!per_sample_weights.defined());
      }
      return embedding_bag_backward_cuda_sum_avg(
          grad, indices, offset2bag, bag_size_, num_weights,
          scale_grad_by_freq, mode, per_sample_weights, padding_idx);
    }

    default:
      TORCH_CHECK(false,
          "Unknown mode for embedding_bag_backward_cuda ", mode);
  }
}

// Computes the gradient with respect to per_sample_weights.
//
// For every sample i: output[i] = dot(grad[offset2bag[i], :], weight[indices[i], :]),
// or 0 when indices[i] == padding_idx.
//
// Work decomposition: one warp per sample. Warps walk over samples with a
// stride equal to the total number of warps in the grid; the lanes of a warp
// split the feature dimension and combine their partial sums with a warp
// reduction, after which lane 0 writes the result.
//
// Preconditions: 1-D launch whose blockDim.x is a multiple of C10_WARP_SIZE,
// so that all lanes of a warp share the same sample and take part in the warp
// reduction together. `indices` and `offset2bag` are read as contiguous
// arrays; `grad` and `weight` are addressed through their strides.
//
// All thread/warp/sample indices are 64-bit so that large batches
// (num_samples * warp size beyond 2^31) and int64 index values are not
// truncated.
template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
    const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
    const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
    const index_t* indices,  // contiguous
    const index_t* offset2bag,  // contiguous
    int64_t num_samples,
    int64_t embedding_features,
    scalar_t* output,
    index_t padding_idx) {
  using accscalar_t = acc_type<scalar_t, true>;
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t warp = idx / C10_WARP_SIZE;
  const int thread_in_warp = static_cast<int>(idx % C10_WARP_SIZE);
  const int64_t num_warps =
      static_cast<int64_t>(blockDim.x) * gridDim.x / C10_WARP_SIZE;

  // Each warp is responsible for the accumulation of one sample.
  // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx].
  for (int64_t sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) {
    accscalar_t result = 0.;
    const int64_t bag_idx = static_cast<int64_t>(offset2bag[sample_idx]);
    const index_t embedding_idx = indices[sample_idx];
    if (embedding_idx != padding_idx) {
      for (int64_t feature_idx = thread_in_warp; feature_idx < embedding_features;
          feature_idx += C10_WARP_SIZE) {
        result +=
            grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] *
            weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx];
      }
    }
    // Executed by every lane (padding samples contribute 0) so the warp
    // reduction always sees the full warp.
    result = cuda_utils::WarpReduceSum<accscalar_t>(result);
    if (thread_in_warp == 0) {
      output[sample_idx] = result;
    }
  }
}

// Gradient of embedding_bag with respect to per_sample_weights (SUM mode only).
//
// Returns a 1-D tensor with one element per entry of `indices_`, holding the
// dot product of the entry's bag gradient row and its embedding row (0 for
// entries equal to `padding_idx`). An empty tensor is returned without
// launching a kernel when there are no samples.
//
// `grad` must be 2-D, `weight` 2-D with weight.size(1) == grad.size(1), and
// the (promoted) indices 1-D. `offsets_` only takes part in the dtype
// promotion of the indices; its values are not read here.
//
// The floating-point dispatch covers Half and BFloat16, matching the dtypes
// accepted by the forward pass in _embedding_bag_cuda.
Tensor _embedding_bag_per_sample_weights_backward_cuda(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices_,
    const Tensor& offsets_,
    const Tensor& offset2bag,
    int64_t mode,
    int64_t padding_idx) {
  TORCH_CHECK(
      mode == EmbeddingBagMode::SUM,
      "embedding_bag_backward: per_sample_weights only supported for mode='sum'");

  AT_ASSERT(grad.dim() == 2);
  auto embedding_features = grad.size(1);

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  AT_ASSERT(indices.dim() == 1);
  auto num_samples = indices.size(0);

  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(weight.size(1) == embedding_features);

  // One warp per sample: enough blocks so that every sample gets its own warp.
  const int threads_per_block = 512;
  const int warps_per_block = threads_per_block / at::cuda::warp_size();

  dim3 block(threads_per_block);
  dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);

  auto output = at::empty({num_samples}, grad.options());

  // Early return when there is no samples in the batch. This saves unnecessary kernel
  // launch, but also prevents cudaGetLastError() to complain about invalid launch args
  if (num_samples == 0) {
    return output;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16,
    grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
        _embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
            grad.const_data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
            weight.const_data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
            indices.const_data_ptr<index_t>(),
            offset2bag.const_data_ptr<index_t>(),
            num_samples,
            embedding_features,
            output.mutable_data_ptr<scalar_t>(),
            padding_idx);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
    }
  );
  return output;
}

} // namespace at::native
