#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/nll_loss2d_forward_native.h>
#include <ATen/ops/nll_loss2d_backward_native.h>
#endif

namespace at::native {

namespace {

// Hands back a contiguous version of `source` when it is defined.
// An undefined tensor is passed through untouched.
inline Tensor optional_contiguous(const Tensor& source) {
  if (!source.defined()) {
    return source;
  }
  return source.contiguous();
}

// Yields a read-only pointer to the first element of `source`,
// or nullptr when `source` is an undefined tensor.
template <typename scalar_t>
inline const scalar_t* optional_data(const Tensor& source) {
  if (source.defined()) {
    return source.const_data_ptr<scalar_t>();
  }
  return nullptr;
}

using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

// TODO(crcrpar): Think about introducing `canUse32BitIndexMath` and choose int or int64_t for `target`.
//
// Unreduced 2D NLL loss, forward pass.
//
// Every flat index in [0, n_threads) is split into (b, h, w) with the batch
// coordinate varying fastest, and exactly one output element is written:
//   output[b][h][w] = -input[b][t][h][w] * weight[t],  t = target[b][h][w]
// A null `weight` stands for a weight of 1 for every class. Positions whose
// target equals `ignore_index` receive a loss of 0. Any other target outside
// [0, input.size(1)) trips a device-side assertion.
//
// The launcher in this file passes n_threads = batch_size * H * W and uses at
// most CUDA_NUM_THREADS threads per block.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_no_reduce_kernel(
  int64_t n_threads,
  PackedTensorAccessor64<scalar_t, 4> input,
  PackedTensorAccessor64<int64_t, 3> target,
  PackedTensorAccessor64<scalar_t, 3> output,
  const scalar_t* weight,
  int64_t ignore_index
) {
  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t H = input.size(2);
  const int64_t W = input.size(3);

  // The loop index must be 64-bit: n_threads is an int64_t, and a 32-bit
  // index would wrap to a negative value once the number of positions
  // exceeds INT_MAX, producing negative (out-of-bounds) coordinates below.
  CUDA_KERNEL_LOOP_TYPE(index, n_threads, int64_t) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    const int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      output[b][h][w] = static_cast<scalar_t>(0);
      continue;
    }
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    const scalar_t value = input[b][cur_target][h][w];
    const scalar_t cur_weight = weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1);
    output[b][h][w] = -value * cur_weight;
  }
}

// Reduced ('sum' / 'mean') 2D NLL loss, forward pass.
//
// The indexing treats `input` as a dense [sample][class][map_nelem] array and
// `target` as a dense [sample][map_nelem] array (the launcher in this file
// passes contiguous copies). `index_t` is the integer type used for the
// per-sample offsets.
//
// Grid layout: `blocks_per_sample` consecutive blocks work on one sample
// (sample = blockIdx.x / blocks_per_sample); together their threads stride
// over the map_nelem positions of that sample. The two shared scratch arrays
// hold CUDA_NUM_THREADS elements each, so blockDim.x must not exceed that.
//
// Each thread accumulates -input * weight and the weight itself for the
// positions it visits (positions whose target equals `ignore_index` are
// skipped; a null `weight` means weight 1). The per-thread partials are
// combined with cuda_utils::BlockReduceSum and thread 0 of every block
// atomically adds the block totals to *output and *total_weight, so both must
// be zeroed by the caller before the launch.
template <typename scalar_t, typename accscalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_kernel(
  scalar_t* output,
  scalar_t* total_weight,
  const scalar_t* input,
  const int64_t* target,
  const scalar_t* weight,
  int n_classes,
  int map_nelem,
  int blocks_per_sample,
  int64_t ignore_index) {

  accscalar_t input_sum = 0;
  accscalar_t acc_weight = 0;

  index_t sample = blockIdx.x / blocks_per_sample;
  index_t toffset = sample * map_nelem;
  index_t ioffset = sample * map_nelem * n_classes;
  int step = blockDim.x * blocks_per_sample;
  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    // Keep the class index at its stored 64-bit width. Narrowing it to a
    // 32-bit index_t would let an invalid value such as 2^32 wrap to a valid
    // class and slip past the range assertion, and would also compare a
    // truncated value against ignore_index.
    const int64_t t = target[toffset + i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      const scalar_t cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
      const auto input_index = ioffset + i + map_nelem * t;
      CUDA_KERNEL_ASSERT(input_index >= 0);
      input_sum -= input[input_index] * cur_weight;
      acc_weight += cur_weight;
    }
  }

  // Scratch space for the two block-wide reductions.
  __shared__ accscalar_t acc_weight_smem[CUDA_NUM_THREADS];
  __shared__ accscalar_t input_sum_smem[CUDA_NUM_THREADS];

  auto acc_weight_ = cuda_utils::BlockReduceSum(acc_weight, acc_weight_smem);
  auto input_sum_ = cuda_utils::BlockReduceSum(input_sum, input_sum_smem);

  // One atomic update per block and per accumulator.
  if (threadIdx.x == 0) {
    gpuAtomicAdd(total_weight, static_cast<scalar_t>(acc_weight_));
    gpuAtomicAdd(output, static_cast<scalar_t>(input_sum_));
  }
}

// Turns an accumulated loss into an average: the single value stored at
// `output` is divided in place by the single value stored at `total_weight`.
// The body is not guarded by a thread index; the launcher in this file starts
// it with one block of one thread.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_size_average_kernel(
  scalar_t* output,
  const scalar_t* total_weight
) {
  output[0] /= total_weight[0];
}

// Unreduced 2D NLL loss, backward pass.
//
// Every flat index in [0, n_threads) is split into (b, h, w) with the batch
// coordinate varying fastest, and at most one gradient element is written:
//   grad_input[b][t][h][w] = -weight[t] * grad_output[b][h][w],
//   t = target[b][h][w]
// A null `weight` stands for a weight of 1 for every class. Positions whose
// target equals `ignore_index` are skipped. Any other target outside
// [0, grad_input.size(1)) trips a device-side assertion.
//
// Only the entries selected by `target` are touched; the launcher in this
// file zero-fills grad_input beforehand and passes
// n_threads = batch_size * H * W.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_no_reduce_kernel(
  int64_t n_threads,
  PackedTensorAccessor64<int64_t, 3> target,
  PackedTensorAccessor64<scalar_t, 3> grad_output,
  PackedTensorAccessor64<scalar_t, 4> grad_input,
  const scalar_t* weight,
  int64_t ignore_index
) {
  const int64_t batch_size = target.size(0);
  const int64_t H = target.size(1);
  const int64_t W = target.size(2);
  const int64_t n_classes = grad_input.size(1);

  // The loop index must be 64-bit: n_threads is an int64_t, and a 32-bit
  // index would wrap to a negative value once the number of positions
  // exceeds INT_MAX.
  CUDA_KERNEL_LOOP_TYPE(index, n_threads, int64_t) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    const int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      continue;
    }
    // Same range check as the other kernels in this file: without it an
    // invalid class index reads `weight` and writes `grad_input` out of
    // bounds.
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    const scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
    grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
  }
}

// Reduced ('sum' / 'mean') 2D NLL loss, backward pass.
//
// The indexing treats `grad_input` as a dense [sample][class][map_nelem]
// array and `target` as a dense [sample][map_nelem] array. `grad_output` and
// `total_weight` are each read as a single value.
//
// Grid layout: `blocks_per_sample` consecutive blocks work on one sample
// (sample = blockIdx.x / blocks_per_sample); together their threads stride
// over the map_nelem positions of that sample.
//
// For every position whose target t differs from `ignore_index`:
//   grad_input[sample][t][i] = weights[t] * g
// where g = -grad_output (divided by total_weight when `size_average` is
// true) and a null `weights` means weight 1. A target outside [0, n_classes)
// trips a device-side assertion. Entries not selected by `target` are not
// written, so the caller must zero grad_input before the launch.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_kernel(
  scalar_t* grad_input,
  const scalar_t* grad_output,
  const int64_t* target,
  const scalar_t* weights,
  const scalar_t* total_weight,
  bool size_average,
  int n_classes,
  int map_nelem,
  int blocks_per_sample,
  int64_t ignore_index
) {
  const auto grad = -(size_average ? *grad_output / *total_weight
                                   : *grad_output);

  const int step = blockDim.x * blocks_per_sample;

  // Offsets are computed in 64 bits: sample * map_nelem * n_classes can reach
  // the total element count of grad_input, which overflows a 32-bit int for
  // tensors with more than INT_MAX elements.
  const int64_t sample = blockIdx.x / blocks_per_sample;
  const int64_t toffset = sample * map_nelem;
  const int64_t ioffset = toffset * n_classes;

  const auto* const target_thread = target + toffset;
  auto* const grad_input_thread = grad_input + ioffset;

  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    const int64_t t = target_thread[i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      const int64_t grad_input_index = i + map_nelem * t;
      CUDA_KERNEL_ASSERT(grad_input_index >= 0);
      // Store through the index that was just validated.
      grad_input_thread[grad_input_index] = weights != nullptr ? weights[t] * grad
                                                               : grad;
    }
  }
}

// Shape validation shared by the forward and backward implementations.
// Raises (via TORCH_CHECK) unless
//   * `target` is 3-dimensional,
//   * `input` is 4-dimensional,
//   * `weight` is undefined or has one element per class (input.size(1)),
//   * the batch and both spatial extents of `input` and `target` agree.
// The checks run in this order, so sizes are only queried on tensors whose
// rank has already been verified.
void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  const bool target_is_3d = (target.dim() == 3);
  TORCH_CHECK(
      target_is_3d,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of size: : ",
      target.sizes());

  const bool input_is_4d = (input.dim() == 4);
  TORCH_CHECK(
      input_is_4d,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of size: ",
      input.sizes());

  const bool weight_covers_classes =
      !weight.defined() || weight.numel() == input.size(1);
  TORCH_CHECK(
      weight_covers_classes,
      "weight tensor should be defined either for all or no classes");

  const bool same_batch = (input.size(0) == target.size(0));
  const bool same_spatial =
      (input.size(2) == target.size(1)) && (input.size(3) == target.size(2));
  TORCH_CHECK(
      same_batch && same_spatial,
      "input and target batch or spatial sizes don't match: target ",
      target.sizes(),
      ", input ",
      input.sizes());
}

// Host-side implementation of the 2D NLL loss forward pass on CUDA.
//
// After validating the arguments and resizing `total_weight` to a 0-dim
// tensor, one of three routes is taken:
//   * Reduction::None - `output` is resized to (N, H, W) and filled by
//                       nll_loss2d_forward_no_reduce_kernel (nothing is
//                       launched when N * H * W == 0).
//   * empty `target`  - `output` is resized to 0-dim and set to NaN for
//                       'mean' or 0 otherwise; `total_weight` is zeroed.
//   * otherwise       - `output` is resized to 0-dim; both results are zeroed
//                       and nll_loss2d_forward_kernel accumulates the weighted
//                       loss and the weight sum. For 'mean' a follow-up kernel
//                       divides the former by the latter.
void nll_loss2d_forward_out_cuda_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  const bool unreduced = (reduction == at::Reduction::None);

  // See Note [Writing Nondeterministic Operations]
  // The 'sum' and 'mean' reductions accumulate through atomicAdd, which makes
  // them nondeterministic.
  if (!unreduced) {
    at::globalContext().alertNotDeterministic("nll_loss2d_forward_out_cuda_template");
  }

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  if (unreduced) {
    const int64_t n_batch = input.size(0);
    const int64_t height = input.size(2);
    const int64_t width = input.size(3);
    const int64_t n_positions = n_batch * height * width;

    at::native::resize_output(output, {n_batch, height, width});
    if (n_positions == 0) {
      // Nothing to compute, and a kernel launch with 0 blocks is avoided.
      return;
    }

    auto weight_ = optional_contiguous(weight);
    const auto stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_forward_no_reduce_kernel",
        [&] {
          nll_loss2d_forward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(n_positions), CUDA_NUM_THREADS, 0, stream>>>(
                  n_positions,
                  input.packed_accessor64<scalar_t, 4>(),
                  target.packed_accessor64<int64_t, 3>(),
                  output.packed_accessor64<scalar_t, 3>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  // From here on the result is a single scalar.
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Target (and therefore input) hold no elements.
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == at::Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_ = input.contiguous();
  auto weight_ = optional_contiguous(weight);
  auto target_ = target.contiguous();

  // The reduction kernel adds into both results, so they start from zero.
  output.zero_();
  total_weight.zero_();

  // Launch geometry: a group of `blocks_per_sample` blocks per batch entry.
  const int64_t n_batch = target.size(0);
  const int64_t elems_per_sample = target.numel() / n_batch;
  const int scaled_blocks = GET_BLOCKS(elems_per_sample) / 128;
  const int blocks_per_sample = (scaled_blocks == 0) ? 1 : scaled_blocks;
  const int total_blocks = static_cast<int>(blocks_per_sample * n_batch);

  const bool use_mean = (reduction == at::Reduction::Mean);
  const ScalarType index_dtype =
      at::native::canUse32BitIndexMath(input_, INT_MAX) ? ScalarType::Int
                                                        : ScalarType::Long;
  const auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_forward_kernel",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        AT_DISPATCH_INDEX_TYPES(
            index_dtype,
            "nll_loss2d_forward_launcher",
            [&] {
              nll_loss2d_forward_kernel<scalar_t, accscalar_t, index_t>
                  <<<total_blocks, CUDA_NUM_THREADS, 0, stream>>>(
                      output.mutable_data_ptr<scalar_t>(),
                      total_weight.mutable_data_ptr<scalar_t>(),
                      input_.const_data_ptr<scalar_t>(),
                      target_.const_data_ptr<int64_t>(),
                      optional_data<scalar_t>(weight_),
                      input_.size(1),
                      input_.size(2) * input_.size(3),
                      blocks_per_sample,
                      ignore_index);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            });

        // Divide by total_weight
        if (use_mean) {
          nll_loss2d_forward_size_average_kernel<scalar_t>
              <<<1, 1, 0, stream>>>(
                  output.mutable_data_ptr<scalar_t>(),
                  total_weight.const_data_ptr<scalar_t>());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      });
}

// Host-side implementation of the 2D NLL loss backward pass on CUDA.
//
// Validates the arguments, reshapes `grad_input` like `input` and zero-fills
// it, then writes the gradient entries selected by `target`:
//   * Reduction::None - `grad_output` must have the same (N, H, W) shape as
//                       `target`; nll_loss2d_backward_no_reduce_kernel is
//                       launched unless N * H * W == 0.
//   * otherwise       - nll_loss2d_backward_kernel is launched unless the
//                       target is empty; for 'mean' the gradient is divided
//                       by `total_weight`.
// `total_weight` must hold exactly one element.
void nll_loss2d_backward_out_cuda_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);

  // The kernels only write the entries picked by `target`, so the gradient
  // starts out as zeros shaped like `input`.
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  const auto stream = at::cuda::getCurrentCUDAStream();

  if (reduction == at::Reduction::None) {
    TORCH_CHECK(
        grad_output.dim() == 3,
        "grad_output must have same dimension as target (3) but got dimension: ",
        grad_output.sizes());
    const bool grad_matches_target =
        grad_output.size(0) == target.size(0) &&
        grad_output.size(1) == target.size(1) &&
        grad_output.size(2) == target.size(2);
    TORCH_CHECK(
        grad_matches_target,
        "grad_output sizes don't match target sizes: target ",
        target.sizes(),
        ", grad_output ",
        grad_output.sizes());

    const int64_t n_positions = input.size(0) * input.size(2) * input.size(3);
    if (n_positions == 0) {
      // Nothing to compute, and a kernel launch with 0 blocks is avoided.
      return;
    }

    auto weight_ = optional_contiguous(weight);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_backward_no_reduce_kernel",
        [&] {
          nll_loss2d_backward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(n_positions), CUDA_NUM_THREADS, 0, stream>>>(
                  n_positions,
                  target.packed_accessor64<int64_t, 3>(),
                  grad_output.packed_accessor64<scalar_t, 3>(),
                  grad_input.packed_accessor64<scalar_t, 4>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  const int64_t n_batch = target.size(0);
  const auto target_numel = target.numel();
  if (n_batch == 0 || target_numel == 0) {
    // Nothing to compute, and a kernel launch with 0 blocks is avoided.
    return;
  }

  auto target_ = target.contiguous();
  auto weight_ = optional_contiguous(weight);

  // Launch geometry: a group of `blocks_per_sample` blocks per batch entry.
  const int64_t elems_per_sample = target_numel / n_batch;
  const int scaled_blocks = GET_BLOCKS(elems_per_sample) / 128;
  const int blocks_per_sample = (scaled_blocks == 0) ? 1 : scaled_blocks;
  const int total_blocks = static_cast<int>(blocks_per_sample * n_batch);
  const bool use_mean = (reduction == at::Reduction::Mean);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_backward_kernel",
      [&] {
        nll_loss2d_backward_kernel<scalar_t>
            <<<total_blocks, CUDA_NUM_THREADS, 0, stream>>>(
                grad_input.mutable_data_ptr<scalar_t>(),
                grad_output.const_data_ptr<scalar_t>(),
                target_.const_data_ptr<int64_t>(),
                optional_data<scalar_t>(weight_),
                total_weight.const_data_ptr<scalar_t>(),
                use_mean,
                input.size(1),
                elems_per_sample,
                blocks_per_sample,
                ignore_index);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}
} // namespace

// Out-variant entry point of the forward pass: computes the loss into the
// caller-provided `output` and `total_weight` tensors and returns references
// to those same two tensors.
std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  nll_loss2d_forward_out_cuda_template(
      output,
      total_weight,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index);
  return std::tie(output, total_weight);
}

// Functional entry point of the forward pass: allocates empty result tensors
// with the options of `self`, lets the shared implementation resize and fill
// them, and returns them as (output, total_weight).
std::tuple<Tensor, Tensor> nll_loss2d_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  const auto result_options = self.options();
  Tensor output = at::empty({0}, result_options);
  Tensor total_weight = at::empty({0}, result_options);

  nll_loss2d_forward_out_cuda_template(
      output,
      total_weight,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index);

  return std::tuple<Tensor, Tensor>(output, total_weight);
}

// Out-variant entry point of the backward pass: computes the gradient with
// respect to `self` into the caller-provided `grad_input` and returns a
// reference to it.
Tensor& nll_loss2d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  // The shared implementation takes the destination first and the input
  // tensor under the name `input`.
  nll_loss2d_backward_out_cuda_template(
      grad_input, grad_output, self, target,
      weight_opt, reduction, ignore_index, total_weight);
  return grad_input;
}

// Functional entry point of the backward pass: allocates a gradient tensor
// shaped like `self`, lets the shared implementation fill it, and returns it.
Tensor nll_loss2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  Tensor grad_input = at::empty_like(self);
  nll_loss2d_backward_out_cuda_template(
      grad_input, grad_output, self, target,
      weight_opt, reduction, ignore_index, total_weight);
  return grad_input;
}

} // namespace at::native
