#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

TORCH_LIBRARY(c10d, m) {
  // ProcessGroup, Work, and ReduceOp are registered here so that the operator
  // schemas below can name them as argument and return types. These
  // registrations behave like declarations: apart from constructors and
  // Work::wait, no details of the three classes are exposed to TorchScript.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());

  // Schemas of every c10d op, registered in the order listed. The kernels
  // for each dispatch key are attached separately via TORCH_LIBRARY_IMPL.
  static constexpr const char* kOpSchemas[] = {
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)",
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)",
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)",
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)",
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work",
      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work",
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)",
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)",
      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)",
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)",
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work",
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()",
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work",
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work",
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work",
  };
  for (const char* op_schema : kOpSchemas) {
    m.def(op_schema);
  }
}
} // namespace

namespace ops {

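// Below are ProcessGroup's corresponding ops for each backend. Ops are
// routed through the dispatcher so that each call reaches the implementation
// registered for its device type; that implementation forwards the call to
// the backend returned by ProcessGroup::getBackend for the device.
// NOTE(review): this comment previously described the routing as "currently a
// no-op as the process group does not have a list of backends". The
// implementations below do select a per-device backend, so that remark looks
// stale -- confirm before relying on it.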

namespace {

// Defines send<DEV>: copies `tensors` into a vector and forwards it to send()
// on the backend that `process_group` returns for device type DEV. `dstRank`
// and `tag` are narrowed to int with static_cast before the call. Returns the
// Work handle produced by the backend.
#define IMPL_SEND(DEV)                                                        \
  c10::intrusive_ptr<Work> send##DEV(                                         \
      at::TensorList tensors,                                                 \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      int64_t dstRank,                                                        \
      int64_t tag) {                                                          \
    auto outgoing = tensors.vec();                                            \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->send(                                                     \
        outgoing, static_cast<int>(dstRank), static_cast<int>(tag));          \
  }

IMPL_SEND(CPU)
IMPL_SEND(CUDA)
IMPL_SEND(PrivateUse1)

// Defines recv_<DEV>: copies `tensors` into a vector and forwards it to recv()
// on the backend that `process_group` returns for device type DEV. `srcRank`
// and `tag` are narrowed to int with static_cast before the call. Returns the
// Work handle produced by the backend.
#define IMPL_RECV(DEV)                                                        \
  c10::intrusive_ptr<Work> recv_##DEV(                                        \
      at::TensorList tensors,                                                 \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      int64_t srcRank,                                                        \
      int64_t tag) {                                                          \
    auto incoming = tensors.vec();                                            \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->recv(                                                     \
        incoming, static_cast<int>(srcRank), static_cast<int>(tag));          \
  }

IMPL_RECV(CPU)
IMPL_RECV(CUDA)
IMPL_RECV(PrivateUse1)

// Defines recv_any_source_<DEV>: like recv_, but takes no source rank; the
// call is forwarded to recvAnysource() on the DEV backend with `tag` narrowed
// to int. Returns the Work handle produced by the backend.
#define IMPL_RECV_ANY_SOURCE(DEV)                            \
  c10::intrusive_ptr<Work> recv_any_source_##DEV(            \
      at::TensorList tensors,                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group, \
      int64_t tag) {                                         \
    auto incoming = tensors.vec();                           \
    const int narrow_tag = static_cast<int>(tag);            \
    return process_group->getBackend(c10::DeviceType::DEV)   \
        ->recvAnysource(incoming, narrow_tag);               \
  }

IMPL_RECV_ANY_SOURCE(CPU)
IMPL_RECV_ANY_SOURCE(CUDA)
IMPL_RECV_ANY_SOURCE(PrivateUse1)

// Defines reduce_<DEV>: bundles the reduction op, root rank, root tensor,
// timeout (wrapped as std::chrono::milliseconds) and asyncOp into a
// ReduceOptions and forwards the copied tensor list to reduce() on the DEV
// backend. Returns the Work handle produced by the backend.
#define IMPL_REDUCE(DEV)                                     \
  c10::intrusive_ptr<Work> reduce_##DEV(                     \
      at::TensorList tensors,                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group, \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,         \
      int64_t root_rank,                                     \
      int64_t root_tensor,                                   \
      bool asyncOp,                                          \
      int64_t timeout) {                                     \
    auto reduce_buf = tensors.vec();                         \
    const auto backend =                                     \
        process_group->getBackend(c10::DeviceType::DEV);     \
    const ReduceOptions reduce_opts{                         \
        *reduce_op,                                          \
        root_rank,                                           \
        root_tensor,                                         \
        std::chrono::milliseconds(timeout),                  \
        asyncOp};                                            \
    return backend->reduce(reduce_buf, reduce_opts);         \
  }

IMPL_REDUCE(CPU)
IMPL_REDUCE(CUDA)
IMPL_REDUCE(PrivateUse1)

// Defines broadcast_<DEV>: forwards the copied tensor list to broadcast() on
// the DEV backend with a BroadcastOptions built from root_rank, root_tensor,
// timeout (wrapped as std::chrono::milliseconds) and asyncOp. The tensor list
// is handed back alongside the Work handle.
#define IMPL_BROADCAST(DEV)                                               \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>           \
      broadcast_##DEV(                                                    \
          at::TensorList tensors,                                         \
          const c10::intrusive_ptr<ProcessGroup>& process_group,          \
          int64_t root_rank,                                              \
          int64_t root_tensor,                                            \
          bool asyncOp,                                                   \
          int64_t timeout) {                                              \
    auto bcast_buf = tensors.vec();                                       \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV); \
    const BroadcastOptions bcast_opts{                                    \
        root_rank,                                                        \
        root_tensor,                                                      \
        std::chrono::milliseconds(timeout),                               \
        asyncOp};                                                         \
    auto work = backend->broadcast(bcast_buf, bcast_opts);                \
    return std::make_tuple(std::move(bcast_buf), work);                   \
  }

IMPL_BROADCAST(CPU)
IMPL_BROADCAST(CUDA)
IMPL_BROADCAST(PrivateUse1)

// Defines allreduce_<DEV>. The allreduce itself is in place, but the input
// tensors are also returned as outputs so that the op looks like a functional
// API and make_fx can correctly build the dependencies in the graph later.
// `sparse_indices` is part of the op signature but is not forwarded by this
// implementation.
#define IMPL_ALLREDUCE(DEV)                                               \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>           \
      allreduce_##DEV(                                                    \
          at::TensorList tensors,                                         \
          const c10::intrusive_ptr<ProcessGroup>& process_group,          \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                  \
          const std::optional<at::Tensor>& sparse_indices,                \
          bool asyncOp,                                                   \
          int64_t timeout) {                                              \
    auto reduce_buf = tensors.vec();                                      \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV); \
    const AllreduceOptions reduce_opts{                                   \
        *reduce_op, std::chrono::milliseconds(timeout), asyncOp};         \
    auto work = backend->allreduce(reduce_buf, reduce_opts);              \
    return std::make_tuple(std::move(reduce_buf), work);                  \
  }

IMPL_ALLREDUCE(CPU)
IMPL_ALLREDUCE(CUDA)
IMPL_ALLREDUCE(PrivateUse1)

// Defines allreduce_coalesced_<DEV>: starts from a value-initialized
// AllreduceCoalescedOptions, overrides reduceOp, timeout (wrapped as
// std::chrono::milliseconds) and asyncOp, and forwards the copied tensor list
// to allreduce_coalesced() on the DEV backend. Returns the backend's Work
// handle.
#define IMPL_ALLREDUCE_COALESCED(DEV)                             \
  c10::intrusive_ptr<Work> allreduce_coalesced_##DEV(             \
      at::TensorList tensors,                                     \
      const c10::intrusive_ptr<ProcessGroup>& process_group,      \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,              \
      bool asyncOp,                                               \
      int64_t timeout) {                                          \
    auto coalesced_buf = tensors.vec();                           \
    AllreduceCoalescedOptions coalesced_opts{};                   \
    coalesced_opts.asyncOp = asyncOp;                             \
    coalesced_opts.timeout = std::chrono::milliseconds(timeout);  \
    coalesced_opts.reduceOp = *reduce_op;                         \
    const auto backend =                                          \
        process_group->getBackend(c10::DeviceType::DEV);          \
    return backend->allreduce_coalesced(                          \
        coalesced_buf, coalesced_opts);                           \
  }

IMPL_ALLREDUCE_COALESCED(CPU)
IMPL_ALLREDUCE_COALESCED(CUDA)
IMPL_ALLREDUCE_COALESCED(PrivateUse1)

// Defines allgather_<DEV>. The caller's output lists are passed to the
// backend through a const_cast, and a copy of those lists (the tensors, not
// their storage) is returned together with the Work handle so that the op can
// be used in a functional manner.
#define IMPL_ALLGATHER(DEV)                                                    \
  std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>   \
      allgather_##DEV(                                                         \
          const std::vector<std::vector<at::Tensor>>& output_tensors,          \
          at::TensorList input_tensors,                                        \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          bool asyncOp,                                                        \
          int64_t timeout) {                                                   \
    auto gather_inputs = input_tensors.vec();                                  \
    auto& gather_outputs =                                                     \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);     \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);      \
    const AllgatherOptions gather_opts{                                        \
        std::chrono::milliseconds(timeout), asyncOp};                          \
    auto work =                                                                \
        backend->allgather(gather_outputs, gather_inputs, gather_opts);        \
    return std::make_tuple(output_tensors, work);                              \
  }

// NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
IMPL_ALLGATHER(CPU)
IMPL_ALLGATHER(CUDA)
IMPL_ALLGATHER(PrivateUse1)

// Defines _allgather_base_<DEV>: single-tensor variant that forwards
// `output_tensor` and `input_tensor` to _allgather_base() on the DEV backend
// with an AllgatherOptions built from timeout (wrapped as
// std::chrono::milliseconds) and asyncOp. Returns `output_tensor` together
// with the Work handle.
#define IMPL__ALLGATHER_BASE(DEV)                                          \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_##DEV(  \
      at::Tensor& output_tensor,                                           \
      at::Tensor& input_tensor,                                            \
      const c10::intrusive_ptr<ProcessGroup>& process_group,               \
      bool asyncOp,                                                        \
      int64_t timeout) {                                                   \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);  \
    const AllgatherOptions opts{                                           \
        std::chrono::milliseconds(timeout), asyncOp};                      \
    auto work =                                                            \
        backend->_allgather_base(output_tensor, input_tensor, opts);       \
    return std::make_tuple(output_tensor, work);                           \
  }

IMPL__ALLGATHER_BASE(CPU)
IMPL__ALLGATHER_BASE(CUDA)
IMPL__ALLGATHER_BASE(PrivateUse1)

// Defines allgather_coalesced_<DEV>: forwards the output lists (via
// const_cast) and the copied input list to allgather_coalesced() on the DEV
// backend. Only asyncOp is overridden on a value-initialized
// AllgatherOptions; this op takes no timeout argument.
#define IMPL_ALLGATHER_COALESCED(DEV)                                        \
  c10::intrusive_ptr<Work> allgather_coalesced_##DEV(                        \
      const std::vector<std::vector<at::Tensor>>& output_lists,              \
      const at::TensorList& input_list,                                      \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                 \
      bool asyncOp) {                                                        \
    auto gather_inputs = input_list.vec();                                   \
    auto& gather_outputs =                                                   \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists);     \
    AllgatherOptions gather_opts{};                                          \
    gather_opts.asyncOp = asyncOp;                                           \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);    \
    return backend->allgather_coalesced(                                     \
        gather_outputs, gather_inputs, gather_opts);                         \
  }

IMPL_ALLGATHER_COALESCED(CPU)
IMPL_ALLGATHER_COALESCED(CUDA)
IMPL_ALLGATHER_COALESCED(PrivateUse1)

// Defines allgather_into_tensor_coalesced_<DEV>: copies both tensor lists
// into vectors and forwards them to allgather_into_tensor_coalesced() on the
// DEV backend. Only asyncOp is overridden on a value-initialized
// AllgatherOptions; this op takes no timeout argument.
#define IMPL_ALLGATHER_INTO_TENSOR_COALESCED(DEV)                       \
  c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
      at::TensorList outputs,                                           \
      at::TensorList inputs,                                            \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      bool asyncOp) {                                                   \
    auto output_buf = outputs.vec();                                    \
    auto input_buf = inputs.vec();                                      \
    AllgatherOptions gather_opts{};                                     \
    gather_opts.asyncOp = asyncOp;                                      \
    const auto backend =                                                \
        process_group->getBackend(c10::DeviceType::DEV);                \
    return backend->allgather_into_tensor_coalesced(                    \
        output_buf, input_buf, gather_opts);                            \
  }

IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)

#define IMPL_REDUCE_SCATTER(DEV)                                           \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>            \
      reduce_scatter_##DEV(                                                \
          const at::TensorList& output_tensors,                            \
          const std::vector<std::vector<at::Tensor>>& input_tensors,       \
          const c10::intrusive_ptr<ProcessGroup>& process_group,           \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                   \
          bool asyncOp,                                                    \
          int64_t timeout) {                                               \
    auto output_tensors_vec = output_tensors.vec();                        \
    auto work = process_group->getBackend(c10::DeviceType::DEV)            \
                    ->reduce_scatter(                                      \
                        output_tensors_vec,                                \
                        const_cast<std::vector<std::vector<at::Tensor>>&>( \
                            input_tensors),                                \
                        ReduceScatterOptions{                              \
                            *reduce_op.get(),                              \
                            std::chrono::milliseconds(timeout),            \
                            asyncOp});                                     \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(  \
        output_tensors_vec, work);                                         \
  }

IMPL_REDUCE_SCATTER(CPU)
IMPL_REDUCE_SCATTER(CUDA)
IMPL_REDUCE_SCATTER(PrivateUse1)

#define IMPL__REDUCE_SCATTER_BASE(DEV)                                         \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_##DEV( \
      at::Tensor& output_tensor,                                               \
      at::Tensor& input_tensor,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto work = process_group->getBackend(c10::DeviceType::DEV)                \
                    ->_reduce_scatter_base(                                    \
                        output_tensor,                                         \
                        input_tensor,                                          \
                        ReduceScatterOptions{                                  \
                            *reduce_op.get(),                                  \
                            std::chrono::milliseconds(timeout),                \
                            asyncOp});                                         \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                   \
        output_tensor, work);                                                  \
  }

IMPL__REDUCE_SCATTER_BASE(CPU)
IMPL__REDUCE_SCATTER_BASE(CUDA)
IMPL__REDUCE_SCATTER_BASE(PrivateUse1)

// Defines reduce_scatter_tensor_coalesced_<DEV>: copies both tensor lists
// into vectors and forwards them to reduce_scatter_tensor_coalesced() on the
// DEV backend with a ReduceScatterOptions built from the reduction op,
// timeout (wrapped as std::chrono::milliseconds) and asyncOp.
#define IMPL_REDUCE_SCATTER_TENSOR_COALESCED(DEV)                       \
  c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_##DEV( \
      at::TensorList outputs,                                           \
      at::TensorList inputs,                                            \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                    \
      bool asyncOp,                                                     \
      int64_t timeout) {                                                \
    auto output_buf = outputs.vec();                                    \
    auto input_buf = inputs.vec();                                      \
    const auto backend =                                                \
        process_group->getBackend(c10::DeviceType::DEV);                \
    const ReduceScatterOptions scatter_opts{                            \
        *reduce_op, std::chrono::milliseconds(timeout), asyncOp};       \
    return backend->reduce_scatter_tensor_coalesced(                    \
        output_buf, input_buf, scatter_opts);                           \
  }

IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)

// Defines gather_<DEV>: forwards the output lists (via const_cast) and the
// copied input list to gather() on the DEV backend with a GatherOptions built
// from root_rank, timeout (wrapped as std::chrono::milliseconds) and asyncOp.
// Returns the Work handle produced by the backend.
#define IMPL_GATHER(DEV)                                                       \
  c10::intrusive_ptr<Work> gather_##DEV(                                       \
      const std::vector<std::vector<at::Tensor>>& output_tensors,              \
      const at::TensorList& input_tensors,                                     \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto gather_inputs = input_tensors.vec();                                  \
    auto& gather_outputs =                                                     \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);     \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);      \
    const GatherOptions gather_opts{                                           \
        root_rank, std::chrono::milliseconds(timeout), asyncOp};               \
    return backend->gather(gather_outputs, gather_inputs, gather_opts);        \
  }

IMPL_GATHER(CPU)
IMPL_GATHER(CUDA)
IMPL_GATHER(PrivateUse1)

// Defines scatter_<DEV>: forwards the copied output list and the input lists
// (via const_cast) to scatter() on the DEV backend with a ScatterOptions
// built from root_rank, timeout (wrapped as std::chrono::milliseconds) and
// asyncOp. Returns the output list together with the Work handle.
#define IMPL_SCATTER(DEV)                                                      \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_##DEV( \
      const at::TensorList& output_tensors,                                    \
      const std::vector<std::vector<at::Tensor>>& input_tensors,               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto scatter_outputs = output_tensors.vec();                               \
    auto& scatter_inputs =                                                     \
        const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors);      \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);      \
    const ScatterOptions scatter_opts{                                         \
        root_rank, std::chrono::milliseconds(timeout), asyncOp};               \
    auto work =                                                                \
        backend->scatter(scatter_outputs, scatter_inputs, scatter_opts);       \
    return std::make_tuple(std::move(scatter_outputs), work);                  \
  }

IMPL_SCATTER(CPU)
IMPL_SCATTER(CUDA)
IMPL_SCATTER(PrivateUse1)

// Defines alltoall_<DEV>: copies both tensor lists into vectors and forwards
// them to alltoall() on the DEV backend with an AllToAllOptions built from
// timeout (wrapped as std::chrono::milliseconds) and asyncOp. Returns the
// output list together with the Work handle.
#define IMPL_ALLTOALL(DEV)                                                     \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>                \
      alltoall_##DEV(                                                          \
          const at::TensorList& output_tensors,                                \
          const at::TensorList& input_tensors,                                 \
          const c10::intrusive_ptr<ProcessGroup>& process_group,               \
          bool asyncOp,                                                        \
          int64_t timeout) {                                                   \
    auto output_buf = output_tensors.vec();                                    \
    auto input_buf = input_tensors.vec();                                      \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);      \
    const AllToAllOptions a2a_opts{                                            \
        std::chrono::milliseconds(timeout), asyncOp};                          \
    auto work = backend->alltoall(output_buf, input_buf, a2a_opts);            \
    return std::make_tuple(std::move(output_buf), work);                       \
  }

IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)

// Defines alltoall_base_<DEV>: single-tensor variant that forwards `output`,
// `input` and the two split-size vectors to alltoall_base() on the DEV
// backend with an AllToAllOptions built from timeout (wrapped as
// std::chrono::milliseconds) and asyncOp. Returns the backend's Work handle.
#define IMPL_ALLTOALL_BASE(DEV)                                            \
  c10::intrusive_ptr<Work> alltoall_base_##DEV(                            \
      at::Tensor& output,                                                  \
      at::Tensor& input,                                                   \
      const c10::intrusive_ptr<ProcessGroup>& process_group,               \
      std::vector<int64_t> output_split_sizes,                             \
      std::vector<int64_t> input_split_sizes,                              \
      bool asyncOp,                                                        \
      int64_t timeout) {                                                   \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);  \
    const AllToAllOptions a2a_opts{                                        \
        std::chrono::milliseconds(timeout), asyncOp};                      \
    return backend->alltoall_base(                                         \
        output, input, output_split_sizes, input_split_sizes, a2a_opts);   \
  }

IMPL_ALLTOALL_BASE(CPU)
IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)

// NOLINTBEGIN(performance-unnecessary-value-param)
// Defines barrier<DEV>: the leading tensor argument is unused by the body.
// A value-initialized BarrierOptions has device_ids, timeout (wrapped as
// std::chrono::milliseconds) and asyncOp filled in and is passed to
// barrier() on the DEV backend. Returns the backend's Work handle.
#define IMPL_BARRIER(DEV)                                                  \
  c10::intrusive_ptr<Work> barrier##DEV(                                   \
      at::Tensor /* unused */,                                             \
      const c10::intrusive_ptr<ProcessGroup>& process_group,               \
      const std::vector<int64_t>& device_ids,                              \
      bool asyncOp,                                                        \
      int64_t timeout) {                                                   \
    BarrierOptions barrier_opts{};                                         \
    barrier_opts.asyncOp = asyncOp;                                        \
    barrier_opts.timeout = std::chrono::milliseconds(timeout);             \
    barrier_opts.device_ids = device_ids;                                  \
    const auto backend = process_group->getBackend(c10::DeviceType::DEV);  \
    return backend->barrier(barrier_opts);                                 \
  }

IMPL_BARRIER(CPU)
IMPL_BARRIER(CUDA)
IMPL_BARRIER(PrivateUse1)
// NOLINTEND(performance-unnecessary-value-param)
// NOLINTEND(cppcoreguidelines-pro-type-const-cast)

// CPU-only monitored barrier. The leading tensor argument is unused by the
// body. Builds a BarrierOptions from `device_ids` and `timeout` (wrapped as
// std::chrono::milliseconds), leaving the remaining fields at their defaults,
// and passes it with `wait_all_ranks` to monitoredBarrier() on the CPU
// backend. Nothing is returned.
void monitored_barrier_CPU(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  const auto cpu_backend = process_group->getBackend(c10::DeviceType::CPU);
  const BarrierOptions barrier_opts{
      device_ids, std::chrono::milliseconds(timeout)};
  cpu_backend->monitoredBarrier(barrier_opts, wait_all_ranks);
}

// allreduce_ variant that calls allreduce_sparse() on the CUDA backend.
// Unlike the per-device allreduce_ kernels defined by IMPL_ALLREDUCE, it
// passes `sparse_indices` on to the backend as the fourth AllreduceOptions
// initializer. The tensor list is handed back alongside the Work handle.
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
allreduce_sparse_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    const std::optional<at::Tensor>& sparse_indices,
    bool asyncOp,
    int64_t timeout) {
  auto sparse_buf = tensors.vec();
  const auto cuda_backend = process_group->getBackend(c10::DeviceType::CUDA);
  const AllreduceOptions sparse_opts{
      *reduce_op, std::chrono::milliseconds(timeout), asyncOp, sparse_indices};
  auto work = cuda_backend->allreduce_sparse(sparse_buf, sparse_opts);
  return std::make_tuple(std::move(sparse_buf), work);
}
} // namespace

// register functions to dispatcher
namespace {

// Inner helper (2nd level of expansion).
// OP:  op name as it appears in the schema (stringified for m.impl)
// KEY: dispatch key; also the suffix of the kernel function, so the kernel
//      registered is the function named OP##KEY (e.g. send + CPU -> sendCPU)
#define REGISTER_C10D_OP1(OP, KEY)     \
  TORCH_LIBRARY_IMPL(c10d, KEY, lib) { \
    lib.impl(#OP, OP##KEY);            \
  }

// Outer helper (1st level of expansion): registers OP for all three device
// keys that have a kernel generated by the IMPL_* macros.
#define REGISTER_C10D_OP(OP)  \
  REGISTER_C10D_OP1(OP, CPU)  \
  REGISTER_C10D_OP1(OP, CUDA) \
  REGISTER_C10D_OP1(OP, PrivateUse1)

// Register every op that has CPU, CUDA and PrivateUse1 kernels.

REGISTER_C10D_OP(send)
REGISTER_C10D_OP(recv_)
REGISTER_C10D_OP(recv_any_source_)
REGISTER_C10D_OP(reduce_)
REGISTER_C10D_OP(broadcast_)
REGISTER_C10D_OP(allreduce_)
REGISTER_C10D_OP(allreduce_coalesced_)
REGISTER_C10D_OP(allgather_)
REGISTER_C10D_OP(_allgather_base_)
REGISTER_C10D_OP(allgather_coalesced_)
REGISTER_C10D_OP(allgather_into_tensor_coalesced_)
REGISTER_C10D_OP(reduce_scatter_)
REGISTER_C10D_OP(_reduce_scatter_base_)
REGISTER_C10D_OP(reduce_scatter_tensor_coalesced_)
REGISTER_C10D_OP(gather_)
REGISTER_C10D_OP(scatter_)
REGISTER_C10D_OP(alltoall_)
REGISTER_C10D_OP(alltoall_base_)
REGISTER_C10D_OP(barrier)

// The following ops are specialized, register them separately

// monitored_barrier_ only has a CPU kernel, so it is registered for the CPU
// key alone.
TORCH_LIBRARY_IMPL(c10d, CPU, lib) {
  lib.impl("monitored_barrier_", monitored_barrier_CPU);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
// The SparseCPU key reuses the regular CPU allreduce_ kernel.
TORCH_LIBRARY_IMPL(c10d, SparseCPU, lib) {
  lib.impl("allreduce_", allreduce_CPU);
}

// The SparseCUDA key gets a dedicated allreduce_ kernel that calls the
// backend's allreduce_sparse().
TORCH_LIBRARY_IMPL(c10d, SparseCUDA, lib) {
  lib.impl("allreduce_", allreduce_sparse_cuda_);
}

} // namespace

} // namespace ops
} // namespace c10d
