#pragma once #include namespace torch::jit { // return true if graph is modified TORCH_API bool UnrollLoops(std::shared_ptr& graph); // Only unrolls constant loops. Will unroll them regardless of loop block size TORCH_API bool UnrollConstantLoops(std::shared_ptr& graph); TORCH_API Node* PeelLoop(Node* n, size_t times); // return true if graph is modified TORCH_API bool PeelProfilingLoops(const std::shared_ptr& graph); struct TORCH_API LoopsPeeler { LoopsPeeler(std::function callback, size_t num_iterations = 1) : callback_(std::move(callback)), num_iterations_(num_iterations) {} bool run(const std::shared_ptr& graph); private: void collectLoop(Node* n); void collectLoops(Block* block); void peelLoops(); std::function callback_ = nullptr; Node* in_loop_ = nullptr; std::list loops_to_peel_; size_t num_iterations_ = 1; }; } // namespace torch::jit