#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Resize.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/bincount_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/histc_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at {
namespace cuda {
// Target ratio between global-memory atomics and per-SM shared-memory atomics
// used by CUDA_tensor_histogram when shrinking the grid for the SHARED path.
#define RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD 8
// Grid-stride loop over [0, lim). `IndexType` must be in scope where this is
// expanded. The built-in block/grid coordinates are unsigned 32-bit values, so
// they are widened to IndexType *before* multiplying; otherwise the products
// would be computed (and could wrap) in 32-bit unsigned arithmetic.
#define FOR_KERNEL_LOOP(i, lim)                                        \
  for (IndexType i = static_cast<IndexType>(blockIdx.x) * blockDim.x + \
           threadIdx.x;                                                \
       i < lim;                                                        \
       i += static_cast<IndexType>(gridDim.x) * blockDim.x)

/*
  Memory types used for the 2 histogram implementations.
  See `CUDA_tensor_histogram` below.
    SHARED: each block accumulates into a private histogram in shared memory
            and then atomically merges it into the global output.
    GLOBAL: all threads atomically update the global output directly.
 */
enum class CUDAHistogramMemoryType { SHARED, GLOBAL };
namespace {
/*
  Maps an input value that the caller has already checked to lie in
  [minvalue, maxvalue] to a bin index in [0, nbins).

  Bins are half-open [start, end) except the last one, which is closed
  [start, end] so that `maxvalue` itself is counted (only relevant for histc;
  bincount passes maxvalue == nbins, which no in-range input reaches).
 */
template <typename input_t, typename IndexType>
__device__ static IndexType getBin(
    input_t bVal,
    at::acc_type<input_t, /*is_cuda=*/true> minvalue,
    at::acc_type<input_t, /*is_cuda=*/true> maxvalue,
    int64_t nbins) {
  // Convert straight to IndexType (not int) so bin indices above INT_MAX are
  // not truncated to 32 bits.
  IndexType bin = static_cast<IndexType>(
      (bVal - minvalue) * nbins / (maxvalue - minvalue));
  // `bin == nbins` is the expected case for bVal == maxvalue (closed last
  // bin). `bin > nbins` can also occur: with floating-point bounds `nbins` is
  // converted to the accumulate type and, when not exactly representable
  // (e.g. nbins > 2^24 for float), may round up past nbins. Clamping with
  // `>=` keeps the subsequent atomic write inside the output.
  if (bin >= nbins) {
    bin = nbins - 1;
  } else if (bin < 0) {
    // Defensive: only reachable if the integral product above overflowed.
    bin = 0;
  }
  return bin;
}
}

/*
  Kernel for computing the histogram of the input.

  Every element of `b` whose value lies in [minvalue, maxvalue] contributes
  getOp(linearIndex) to the bin chosen by getBin. Elements are visited with
  FOR_KERNEL_LOOP, so any grid size covers all `totalElements`.

  MemoryType == SHARED: each block zeroes a private histogram of a.sizes[0]
    entries in dynamic shared memory (the launch must supply at least
    a.sizes[0] * sizeof(output_t) bytes), accumulates into it with shared
    atomics, then atomically adds it into `a`.
  MemoryType == GLOBAL: every thread atomically adds directly into `a`.

  `p` (partial output) and PDims are accepted but not read by this kernel.
 */
template <
    typename output_t,
    typename input_t,
    typename IndexType,
    int ADims,
    int PDims,
    int BDims,
    CUDAHistogramMemoryType MemoryType,
    typename Op>
C10_LAUNCH_BOUNDS_1(cuda::getApplyBlockSize())
__global__ void kernelHistogram1D(
    detail::TensorInfo<output_t, IndexType> a, /* output */
    detail::TensorInfo<output_t, IndexType> p, /* partial output */
    detail::TensorInfo<const input_t, IndexType> b, /* input */
    int64_t nbins,
    at::acc_type<input_t, /*is_cuda=*/true> minvalue,
    at::acc_type<input_t, /*is_cuda=*/true> maxvalue,
    IndexType totalElements,
    Op getOp) {
  extern __shared__ unsigned char my_smem[];

  if constexpr (MemoryType == CUDAHistogramMemoryType::SHARED) {
    output_t* const blockHist = reinterpret_cast<output_t*>(my_smem);
    const IndexType histSize = a.sizes[0];

    // Phase 1: clear this block's private histogram.
    for (IndexType slot = threadIdx.x; slot < histSize; slot += blockDim.x) {
      blockHist[slot] = 0;
    }
    __syncthreads();

    // Phase 2: accumulate this block's share of the input in shared memory.
    FOR_KERNEL_LOOP(linearIndex, totalElements) {
      const IndexType inOffset =
          detail::IndexToOffset<const input_t, IndexType, BDims>::get(
              linearIndex, b);
      const auto value = b.data[inOffset];
      const bool inRange = value >= minvalue && value <= maxvalue;
      if (inRange) {
        const IndexType bin =
            getBin<input_t, IndexType>(value, minvalue, maxvalue, nbins);
        gpuAtomicAddNoReturn(&blockHist[bin], getOp(linearIndex));
      }
    }
    __syncthreads();

    // Phase 3: merge into the global output. Atomics are required because
    // __syncthreads() only orders threads within this block, not across blocks.
    for (IndexType slot = threadIdx.x; slot < histSize; slot += blockDim.x) {
      const IndexType outOffset =
          detail::IndexToOffset<output_t, IndexType, ADims>::get(slot, a);
      gpuAtomicAddNoReturn(&a.data[outOffset], blockHist[slot]);
    }
  } else {
    // Global-memory variant: every in-range element updates `a` directly.
    FOR_KERNEL_LOOP(linearIndex, totalElements) {
      const IndexType inOffset =
          detail::IndexToOffset<const input_t, IndexType, BDims>::get(
              linearIndex, b);
      const auto value = b.data[inOffset];
      const bool inRange = value >= minvalue && value <= maxvalue;
      if (inRange) {
        const IndexType bin =
            getBin<input_t, IndexType>(value, minvalue, maxvalue, nbins);
        const IndexType outOffset =
            detail::IndexToOffset<output_t, IndexType, ADims>::get(bin, a);
        gpuAtomicAddNoReturn(&a.data[outOffset], getOp(linearIndex));
      }
    }
  }
}

// Launches kernelHistogram1D for one memory type on the current CUDA stream
// with SHARED_MEM bytes of dynamic shared memory, then runs
// C10_CUDA_KERNEL_LAUNCH_CHECK(). The literal template arguments 1, 2, -1 are
// ADims, PDims and BDims respectively (-1 presumably selects the generic
// strided indexing path in IndexToOffset — TODO confirm). The expansion site
// must have `output_t`, `input_t`, `IndexType`, `grid`, `block`, `aInfo`,
// `pInfo`, `bInfo`, `nbins`, `minvalue`, `maxvalue` and `totalElements` in
// scope.
#define HANDLE_CASE(MEMORY_TYPE, WEIGHTS_OP, SHARED_MEM)                 \
  kernelHistogram1D<                                                     \
      output_t,                                                          \
      input_t,                                                           \
      IndexType,                                                         \
      1,                                                                 \
      2,                                                                 \
      -1,                                                                \
      MEMORY_TYPE><<<grid, block, SHARED_MEM, getCurrentCUDAStream()>>>( \
      aInfo,                                                             \
      pInfo,                                                             \
      bInfo,                                                             \
      nbins,                                                             \
      minvalue,                                                          \
      maxvalue,                                                          \
      totalElements,                                                     \
      WEIGHTS_OP);                                                       \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

// Dispatches on the runtime memory type: SHARED launches the shared-memory
// kernel with `sharedMem` bytes (which must be in scope), anything else
// launches the GLOBAL kernel with no dynamic shared memory. Expands to a bare
// `switch` statement, so call sites do not append a semicolon.
#define HANDLE_SWITCH_CASE(mType, getOp)                                   \
  switch (mType) {                                                         \
    case CUDAHistogramMemoryType::SHARED:                                  \
      HANDLE_CASE(CUDAHistogramMemoryType::SHARED, getOp, sharedMem);      \
      break;                                                               \
    default:                                                               \
      HANDLE_CASE(CUDAHistogramMemoryType::GLOBAL, getOp, 0);              \
  }

/*
  Calculate the frequency of the input values.

  `a` receives the histogram; the kernel accumulates into it with atomics, so
  callers pass a zero-initialized tensor. `b` is the input: only values in
  [minvalue, maxvalue] are counted, each mapped to a bin by getBin. `c`
  optionally holds one weight per input element (HasWeights == true);
  otherwise every in-range element adds 1.
  See `help torch.bincount` for details on the math.

  2 implementations, chosen by how much memory the histogram needs:
    SHARED: the histogram (nbins * sizeof(output_t) plus 8 guard bytes) fits
            in one block's shared memory. Each block atomically adds to its
            own **shared** copy, then atomically updates the global tensor.
            The grid is shrunk to balance shared- vs. global-memory atomics.
    GLOBAL: otherwise; all threads atomically update a single **global**
            histogram copy.

  Precondition: nbins >= 1 (both callers in this file guarantee it).
  Returns false without launching anything when `b` is empty or no launch grid
  is available; otherwise enqueues the kernel on the current CUDA stream
  (asynchronously) and returns true.
 */
template <typename output_t, typename input_t, bool HasWeights>
bool CUDA_tensor_histogram(
    at::Tensor a, /* output */
    at::Tensor b, /* input */
    at::Tensor c, /* weights(optional) */
    int64_t nbins,
    at::acc_type<input_t, /*is_cuda=*/true> minvalue,
    at::acc_type<input_t, /*is_cuda=*/true> maxvalue,
    TensorArgType aType = TensorArgType::ReadWrite,
    TensorArgType bType = TensorArgType::ReadOnly,
    TensorArgType cType = TensorArgType::ReadOnly) {
  checkBackend("CUDA_tensor_histogram", {a, b}, Backend::CUDA);
  if (HasWeights) {
    checkBackend("CUDA_tensor_histogram", {c}, Backend::CUDA);
  }
  auto totalElements = b.numel();

  if (totalElements == 0) {
    return false;
  }

  const dim3 block = getApplyBlock();
  dim3 grid;
  auto curDevice = current_device();
  if (curDevice == -1 || !getApplyGrid(totalElements, grid, curDevice)) {
    return false;
  }

  CUDAHistogramMemoryType memType = CUDAHistogramMemoryType::GLOBAL;
  const auto* const props = getCurrentDeviceProperties();
  const auto maxSharedMem = props->sharedMemPerBlock;
  const auto sharedMem = nbins * sizeof(output_t) + 8; // 8 guard bytes
  // determine memory type to use in the kernel
  if (sharedMem < maxSharedMem) {
    // Solve equations:
    // (1) #(smem atomicAdd per SM) = totalElements / min(grid.x, #SM)
    // (2) #(gmem atomicAdd) = grid.x * nbins
    // (3) RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD = #(gmem atomicAdd) / #(smem atomicAdd per SM)
    //
    // All intermediate grid arithmetic is done in size_t. In 32-bit unsigned
    // arithmetic, `optimalGrid * block.x` can wrap for very large inputs
    // (possibly to a zero divisor), and narrowing the result back to unsigned
    // could yield a zero-sized grid.
    const size_t numSMs = static_cast<size_t>(props->multiProcessorCount);
    const size_t numElements = static_cast<size_t>(totalElements);
    const size_t threadsPerBlock = block.x;
    size_t optimalGrid = ceil_div<size_t>(
        RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD * numElements,
        static_cast<size_t>(nbins) * numSMs);
    if (optimalGrid < numSMs) {
      optimalGrid = 1 + static_cast<size_t>(std::sqrt(
          RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD * totalElements / nbins));
    }
    // optimalGrid >= 1 and numElements >= 1 here, so neither divisor below
    // can be zero and the recomputed grid is at least 1.
    const size_t optimalSteps =
        ceil_div<size_t>(numElements, optimalGrid * threadsPerBlock);
    optimalGrid = ceil_div<size_t>(numElements, optimalSteps * threadsPerBlock);
    // grid.x already fits in unsigned, so the minimum does too.
    grid.x = static_cast<unsigned int>(std::min<size_t>(grid.x, optimalGrid));
    memType = CUDAHistogramMemoryType::SHARED;
  }

  using IndexType = int64_t;
  auto aInfo = detail::getTensorInfo<output_t, IndexType>(a);
  auto bInfo = detail::getTensorInfo<const input_t, IndexType>(b);
  // Partial-output slot of the kernel; not read by kernelHistogram1D.
  detail::TensorInfo<output_t, IndexType> pInfo(nullptr, 0, {}, {});

  if (HasWeights) {
    auto cInfo = detail::getTensorInfo<output_t, IndexType>(c);
    // Each counted element contributes its own weight c[linearIndex].
    const auto getWeightsOp = [cInfo] __device__(IndexType cIndex) {
      const IndexType cOffset =
          detail::IndexToOffset<output_t, IndexType, 1>::get(cIndex, cInfo);
      return cInfo.data[cOffset];
    };
    HANDLE_SWITCH_CASE(memType, getWeightsOp)
  } else {
    // Unweighted: each counted element contributes 1.
    static const auto getDummyOp = [] __device__(IndexType) { return 1L; };
    HANDLE_SWITCH_CASE(memType, getDummyOp)
  }
  return true;
}

// The helper macros defined above are implementation details of the CUDA
// histogram kernel and its launcher; undefine them so they do not leak past
// this point of the translation unit.
#undef HANDLE_CASE
#undef HANDLE_SWITCH_CASE
#undef FOR_KERNEL_LOOP
#undef RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD
} // namespace cuda

namespace {
///////////////// bincount /////////////////
/*
  bincount for one integral input dtype `input_t`, with optional weights of
  dtype `weights_t`.

  Returns a tensor of max(max(self) + 1, minlength) bins: int64 counts when
  `weights` is undefined, otherwise per-bin sums in the weights' dtype.
  An empty 1-D input yields `minlength` int64 zeros.
 */
template <typename input_t, typename weights_t>
Tensor _bincount_cuda_template(
    const Tensor& self,
    const Tensor& weights,
    int64_t minlength) {
  TORCH_CHECK(minlength >= 0, "minlength should be >= 0");

  if (self.dim() == 1 && self.numel() == 0) {
    return at::zeros(
        {minlength},
        kLong,
        std::nullopt /* layout */,
        kCUDA,
        std::nullopt /* pin_memory */);
  }

  // uint8 values can never be negative, so the host-synchronizing min() check
  // is skipped for that dtype.
  constexpr bool input_is_uint8 = std::is_same_v<input_t, uint8_t>;
  const bool input_ok = self.dim() == 1 &&
      (input_is_uint8 || *self.min().cpu().const_data_ptr<input_t>() >= 0);
  TORCH_CHECK(input_ok, "bincount only supports 1-d non-negative integral inputs.");

  const bool has_weights = weights.defined();
  const bool weights_ok = !has_weights ||
      (weights.dim() == 1 && weights.size(0) == self.size(0));
  TORCH_CHECK(weights_ok, "weights should be 1-d and have the same length as input");

  const int64_t nbins =
      std::max(self.max().item<input_t>() + int64_t{1}, minlength);

  // acc_type is used for the bounds (int64_t for integers) so that e.g. 256
  // bins for uint8 input do not overflow the input dtype.
  using bounds_t = at::acc_type<input_t, /*is_cuda=*/true>;
  const bounds_t minvalue = 0;
  const bounds_t maxvalue = nbins;

  if (!has_weights) {
    // Unweighted: plain int64 occurrence counts.
    Tensor counts = at::zeros(
        {nbins},
        kLong,
        std::nullopt /* layout */,
        DeviceType::CUDA,
        std::nullopt /* pin_memory */);
    cuda::CUDA_tensor_histogram<int64_t, input_t, false>(
        counts, self, weights, nbins, minvalue, maxvalue);
    return counts;
  }

  // Weighted: output mirrors the weights' dtype, layout, device and pinning.
  const auto weight_opts = weights.options();
  Tensor weighted = at::zeros(
      {nbins},
      optTypeMetaToScalarType(weight_opts.dtype_opt()),
      weight_opts.layout_opt(),
      weight_opts.device_opt(),
      weight_opts.pinned_memory_opt());
  cuda::CUDA_tensor_histogram<weights_t, input_t, true>(
      weighted, self, weights, nbins, minvalue, maxvalue);
  return weighted;
}

///////////////// histc /////////////////
// Microsoft's STL has a problem with integer overloads of std::fpclassify used
// by std::isnan and std::isinf, as described here:
// https://stackoverflow.com/questions/61646166/how-to-resolve-fpclassify-ambiguous-call-to-overloaded-function
// This macro provides a workaround for this problem.
#if defined(USE_ROCM) && defined(_MSC_VER)
#define STL_CAST_BUG(value) static_cast<double>(value)
#else
#define STL_CAST_BUG(value) value
#endif

/*
  histc for one input dtype: counts the elements of `self` falling into
  `nbins` equal-width bins over [min, max], in the dtype of `self`.

  When min == max and `self` is non-empty, the range is taken from the data
  instead. A range that is still degenerate is widened by 1 on each side.
  The final range must be finite and satisfy min < max.
 */
template <typename input_t>
Tensor _histc_cuda_template(
    const Tensor& self,
    int64_t nbins,
    at::acc_type<input_t, /*is_cuda=*/true> min,
    at::acc_type<input_t, /*is_cuda=*/true> max) {
  TORCH_CHECK(nbins > 0, "bins must be > 0");

  Tensor hist = at::zeros(
      {nbins},
      self.scalar_type(),
      std::nullopt /* layout */,
      DeviceType::CUDA,
      std::nullopt /* pin_memory */);

  using bounds_t = at::acc_type<input_t, /*is_cuda=*/true>;
  bounds_t lower = min;
  bounds_t upper = max;

  // A degenerate requested range means "use the range of the data".
  const bool use_data_range = min == max && self.numel() > 0;
  if (use_data_range) {
    lower = *self.min().cpu().const_data_ptr<input_t>();
    upper = *self.max().cpu().const_data_ptr<input_t>();
  }
  // Still degenerate (e.g. constant input): widen so every value lands in a bin.
  if (lower == upper) {
    lower -= 1;
    upper += 1;
  }

#if !defined(USE_ROCM)
  const bool range_is_finite =
      !(at::_isinf(lower) || at::_isinf(upper) ||
        at::_isnan(lower) || at::_isnan(upper));
#else
  const bool range_is_finite =
      !(std::isinf(STL_CAST_BUG(lower)) ||
        std::isinf(STL_CAST_BUG(upper)) ||
        std::isnan(STL_CAST_BUG(lower)) ||
        std::isnan(STL_CAST_BUG(upper)));
#endif
  TORCH_CHECK(
      range_is_finite,
      "range of [",
      lower,
      ", ",
      upper,
      "] is not finite");
  TORCH_CHECK(lower < upper, "max must be larger than min");

  cuda::CUDA_tensor_histogram<input_t, input_t, false>(
      hist, self, Tensor(), nbins, lower, upper);
  return hist;
}
} // namespace

namespace native {
/*
  CUDA entry point for torch.bincount.

  Dispatches over integral input dtypes. Undefined weights (unweighted count)
  and float32 weights are used as-is; any other weight dtype is converted to
  float64 before counting.
 */
Tensor _bincount_cuda(
    const Tensor& self, const std::optional<Tensor>& weights_opt,
    int64_t minlength) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weights_maybe_owned = at::borrow_from_optional_tensor(weights_opt);
  const Tensor& weights = *weights_maybe_owned;

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic if weights are given, because of floating point
  // atomicAdd usage. Test `weights.defined()` rather than
  // `weights_opt.has_value()`: an optional holding an undefined tensor takes
  // the unweighted (int64, deterministic) path and must not raise the alert.
  if (weights.defined()) {
    globalContext().alertNotDeterministic("_bincount_cuda");
  }
  return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cuda", [&] {
    const auto scalar = weights.scalar_type();
    if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
      return _bincount_cuda_template<scalar_t, float>(self, weights, minlength);
    return _bincount_cuda_template<scalar_t, double>(
        self, weights.to(kDouble), minlength);
  });
}

/*
  CUDA entry point for torch.histc.

  Rejects Half input, raises the nondeterminism alert for floating-point
  input (float atomics make the sums order-dependent), then dispatches to
  _histc_cuda_template with the bounds converted to the dtype's acc_type.
 */
Tensor _histc_cuda(
    const Tensor& self,
    int64_t nbins,
    const Scalar& min,
    const Scalar& max) {
  const ScalarType input_dtype = self.scalar_type();
  TORCH_CHECK(input_dtype != ScalarType::Half, "HalfTensor is not supported");

  // See Note [Writing Nondeterministic Operations]
  if (at::isFloatingType(input_dtype)) {
    globalContext().alertNotDeterministic("_histc_cuda with floating point input");
  }

  return AT_DISPATCH_ALL_TYPES(input_dtype, "histc", [&] {
    using bounds_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
    const bounds_t lower = min.to<bounds_t>();
    const bounds_t upper = max.to<bounds_t>();
    return _histc_cuda_template<scalar_t>(self, nbins, lower, upper);
  });
}

// Out-variant of histc: computes the histogram into a fresh tensor, resizes
// `result` to match, copies the values in, and returns `result`.
Tensor& _histc_out_cuda(const Tensor& self, int64_t bins, const Scalar& min, const Scalar& max, Tensor& result) {
  const Tensor histogram = _histc_cuda(self, bins, min, max);
  resize_output(result, histogram.sizes());
  result.copy_(histogram);
  return result;
}
} // namespace native
} // namespace at
