#pragma once #include #include #include #include #include #include #include #include namespace torch::autograd { struct TORCH_CUDA_CU_API Scatter : public Node { explicit Scatter( std::vector devices, std::optional> chunk_sizes = std::nullopt, int64_t dim = 0, std::optional>> streams = std::nullopt, bool unsqueeze_scalars = false); ~Scatter() override; variable_list apply(variable_list&& inputs) override; std::vector devices_; std::optional> chunk_sizes_; int64_t dim_; std::optional>> streams_; bool unsqueeze_scalars_; }; struct TORCH_CUDA_CU_API Gather : public Node { explicit Gather(const at::Device& destination_device, int64_t dim = 0); ~Gather() override; variable_list apply(variable_list&& inputs) override; at::Device destination_device_; int64_t dim_; }; } // namespace torch::autograd