#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Cross.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

namespace at::native {

// Computes one 3-component cross product (x1 x x2) per index visited by
// CUDA_KERNEL_LOOP.
//
//   numel             - upper bound handed to CUDA_KERNEL_LOOP.
//   out, x1, x2       - base pointers of the output and the two operands.
//   offset_calculator - get(i) yields three offsets, in units of T, that are
//                       added to out, x1 and x2 (in that order) to locate the
//                       vector belonging to index i.
//   ostride, x1stride, x2stride
//                     - distance, in units of T, between consecutive
//                       components of a vector in out / x1 / x2.
//
// Component index arithmetic (k * stride) is carried out in StrideType.
// Every input component is read before any output component is written.
template <typename T, typename OffsetCalc, typename StrideType>
__global__ void cross_kernel(
    int numel, T* out, const T* x1, const T* x2, OffsetCalc offset_calculator,
    StrideType ostride, StrideType x1stride, StrideType x2stride) {
  CUDA_KERNEL_LOOP(i, numel) {
    const auto offs = offset_calculator.get(i);
    T* const dst = out + offs[0];
    const T* const lhs = x1 + offs[1];
    const T* const rhs = x2 + offs[2];

    // Components of the left-hand operand.
    const T a0 = lhs[0 * x1stride];
    const T a1 = lhs[1 * x1stride];
    const T a2 = lhs[2 * x1stride];

    // Components of the right-hand operand.
    const T b0 = rhs[0 * x2stride];
    const T b1 = rhs[1 * x2stride];
    const T b2 = rhs[2 * x2stride];

    // Standard cross-product formula, one component at a time.
    const T c0 = (a1 * b2 - a2 * b1);
    const T c1 = (a2 * b0 - a0 * b2);
    const T c2 = (a0 * b1 - a1 * b0);

    // Stores happen only after all loads above have been issued in program order.
    dst[0 * ostride] = c0;
    dst[1 * ostride] = c1;
    dst[2 * ostride] = c2;
  }
}

// Launches cross_kernel for every element of `iter` on the current CUDA stream.
//
//   iter     - iterator whose operand 0 is written and whose operands 1 and 2
//              are read; its numel() must be positive and fit in int32
//              (checked in debug builds only).
//   ostride, x1stride, x2stride
//            - strides forwarded to the kernel for stepping between the three
//              components of each vector.
//
// The kernel is instantiated with int strides when 2 * stride fits in int for
// all three strides (2 is the largest component multiplier the kernel uses);
// otherwise the strides are passed through as int64_t.
void launch_cross_kernel(const TensorIteratorBase& iter, int64_t ostride,
                         int64_t x1stride, int64_t x2stride) {
  const auto total = iter.numel();
  auto offsets = make_element_offset_calculator<3>(iter);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(total > 0 && total <= std::numeric_limits<int32_t>::max());

  // One thread per element, rounded up to whole blocks.
  int64_t num_blocks = (total + num_threads() - 1) / num_threads();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();

  // Decide once, outside the dtype dispatch, whether narrow strides are safe.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const bool strides_fit_int =
      ostride * 2 <= kIntMax && x1stride * 2 <= kIntMax && x2stride * 2 <= kIntMax;

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "cross_cuda", [&] {
    auto* dst = static_cast<scalar_t*>(iter.data_ptr(0));
    const auto* lhs = static_cast<const scalar_t*>(iter.data_ptr(1));
    const auto* rhs = static_cast<const scalar_t*>(iter.data_ptr(2));

    if (strides_fit_int) {
      // Narrow path: strides are passed as int.
      const int out_step = static_cast<int>(ostride);
      const int lhs_step = static_cast<int>(x1stride);
      const int rhs_step = static_cast<int>(x2stride);
      cross_kernel<<<num_blocks, num_threads(), 0, cuda_stream>>>(
          total, dst, lhs, rhs, offsets, out_step, lhs_step, rhs_step);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // Wide path: at least one doubled stride exceeds int range.
      cross_kernel<<<num_blocks, num_threads(), 0, cuda_stream>>>(
          total, dst, lhs, rhs, offsets, ostride, x1stride, x2stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

// Writes the cross product of x1 and x2 along `dim` into `result`.
//
//   result - output tensor; it is not resized here (resize_outputs(false)).
//   x1, x2 - input tensors, registered as const inputs.
//   dim    - dimension holding the three vector components. It is passed as
//            squash_dims to the iterator, so the iterator presumably does not
//            walk along it (confirm against TensorIteratorConfig); the kernel
//            steps along it using the strides read below.
//
// Returns immediately when the iterator has no elements. When the iterator
// cannot use 32-bit indexing, it is split and each piece is launched
// separately.
void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_t dim) {
  // Strides of each tensor along the cross-product dimension.
  const int64_t out_step = result.stride(dim);
  const int64_t lhs_step = x1.stride(dim);
  const int64_t rhs_step = x2.stride(dim);

  TensorIteratorConfig config;
  config.add_output(result);
  config.add_const_input(x1);
  config.add_const_input(x2);
  config.resize_outputs(false);
  config.declare_static_shape(result.sizes(), /*squash_dims=*/dim);
  auto iter = config.build();

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    // Too large for 32-bit indexing: launch once per sub-iterator.
    for (auto&& piece : iter.with_32bit_indexing()) {
      launch_cross_kernel(piece, out_step, lhs_step, rhs_step);
    }
    return;
  }

  launch_cross_kernel(iter, out_step, lhs_step, rhs_step);
}

// Registers cross_impl as the implementation behind cross_stub.
// NOTE(review): cross_stub is presumably declared in ATen/native/Cross.h — confirm.
REGISTER_DISPATCH(cross_stub, &cross_impl)

} // namespace at::native
