#pragma once #include #include #include #include namespace torch::distributed::autograd { // Used to propagate gradients from one node to another during a distributed // backwards pass. This RPC call is invoked when we hit a `recv` autograd // function during backward pass execution. class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase { public: PropagateGradientsReq( const AutogradMetadata& autogradMetadata, std::vector grads, bool retainGraph = false); const AutogradMetadata& getAutogradMetadata(); const std::vector& getGrads(); // Serialization and deserialization methods. c10::intrusive_ptr toMessageImpl() && override; static std::unique_ptr fromMessage( const rpc::Message& message); // Whether or not to retain the autograd graph. bool retainGraph(); private: AutogradMetadata autogradMetadata_; std::vector grads_; bool retainGraph_; }; } // namespace torch::distributed::autograd