#pragma once #include #include #include namespace torch::autograd::utils { // Warning handler for multi-threaded contexts. Gather warnings from // all threads into a single queue, then process together at the end // in the main thread. class DelayWarningHandler : public at::WarningHandler { public: ~DelayWarningHandler() override = default; void replay_warnings(); private: void process(const c10::Warning& warning) override; std::vector warnings_; std::mutex mutex_; }; } // namespace torch::autograd::utils