#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <gtest/gtest.h>

// Verifies at::cuda::cub::get_num_bits over the full 64-bit range:
//  * 0 reports 1 bit;
//  * for every width k in [1, 64], both the smallest k-bit value (2^(k-1))
//    and the largest k-bit value (2^k - 1) report exactly k bits.
TEST(NumBits, CubTest) {
  using at::cuda::cub::get_num_bits;

  // Zero has no set bits but is still expected to need one bit.
  ASSERT_EQ(get_num_bits(uint64_t{0}), 1);

  for (int width = 1; width <= 64; ++width) {
    const uint64_t lowest = uint64_t{1} << (width - 1);
    // Unsigned wrap-around: for width == 64, (lowest << 1) is 0, so this
    // yields the all-ones value.
    const uint64_t highest = (lowest << 1) - 1;
    ASSERT_EQ(get_num_bits(lowest), width)
        << "smallest value of width " << width;
    ASSERT_EQ(get_num_bits(highest), width)
        << "largest value of width " << width;
  }
}

// Shared input for the scan tests: the integers 1 through 10, declared
// __managed__ so the same array is usable from host and device code.
__managed__ int input[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

// Checks at::cuda::cub::inclusive_scan with max_cub_size = 2, which is smaller
// than the 10-element input (presumably exercising the wrapper's split path,
// per the test name — confirm against ATen/cuda/cub.cuh).
TEST(InclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitDevice(
      c10::DeviceType::CUDA); // This is required to use PyTorch's caching
                              // allocator.

  constexpr int kSize = 10;  // number of elements in `input`

  // Owns the managed output buffer so it is released on every exit path,
  // including the early return taken when an ASSERT_* fails.
  struct ManagedBuffer {
    int *ptr = nullptr;
    ~ManagedBuffer() {
      if (ptr != nullptr) {
        cudaFree(ptr);  // best-effort cleanup; a destructor cannot report errors
      }
    }
  } output1;

  cudaError_t err = cudaMallocManaged(&output1.ptr, sizeof(int) * kSize);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  at::cuda::cub::inclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, /*max_cub_size=*/2>(
    input, output1.ptr, ::at_cuda_detail::cub::Sum(), kSize);
  // Wait for the device work (and surface any asynchronous fault) before the
  // host reads the managed output buffer.
  err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Inclusive prefix sums of 1..10.
  const int expected[kSize] = {1, 3, 6, 10, 15, 21, 28, 36, 45, 55};
  for (int i = 0; i < kSize; ++i) {
    ASSERT_EQ(output1.ptr[i], expected[i]) << "at index " << i;
  }
}

// Checks at::cuda::cub::exclusive_scan (init value 0) with max_cub_size = 2,
// which is smaller than the 10-element input (presumably exercising the
// wrapper's split path, per the test name — confirm against ATen/cuda/cub.cuh).
TEST(ExclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitDevice(
      c10::DeviceType::CUDA); // This is required to use PyTorch's caching
                              // allocator.

  constexpr int kSize = 10;  // number of elements in `input`

  // Owns the managed output buffer so it is released on every exit path,
  // including the early return taken when an ASSERT_* fails.
  struct ManagedBuffer {
    int *ptr = nullptr;
    ~ManagedBuffer() {
      if (ptr != nullptr) {
        cudaFree(ptr);  // best-effort cleanup; a destructor cannot report errors
      }
    }
  } output2;

  cudaError_t err = cudaMallocManaged(&output2.ptr, sizeof(int) * kSize);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  at::cuda::cub::exclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, int, /*max_cub_size=*/2>(
    input, output2.ptr, ::at_cuda_detail::cub::Sum(), 0, kSize);
  // Wait for the device work (and surface any asynchronous fault) before the
  // host reads the managed output buffer.
  err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Exclusive prefix sums of 1..10 starting from 0.
  const int expected[kSize] = {0, 1, 3, 6, 10, 15, 21, 28, 36, 45};
  for (int i = 0; i < kSize; ++i) {
    ASSERT_EQ(output2.ptr[i], expected[i]) << "at index " << i;
  }
}
