#pragma once #include #include namespace torch::jit::fuser::onednn { struct WorkBlock : public std::pair { using pair::pair; Node* begin() { return this->first; } Node* end() { return this->second; } }; class GraphRewriter { public: GraphRewriter(Block* block, std::shared_ptr graph, AliasDb& aliasDb) : block_(block), graph_(std::move(graph)), aliasDb_(aliasDb), llgaHelper_(graph_) {} void cleanupSubgraphs(); void buildupSubgraphs(); private: Block* block_; std::shared_ptr graph_; AliasDb& aliasDb_; LlgaGraphHelper llgaHelper_; std::vector buildWorkBlocks(); std::pair scanNode( Node* consumer, graph_node_list::iterator workblock_begin); std::optional tryMerge(Node* consumer, Node* producer); }; // This pass creates the subgraphs for oneDNN Graph Fusion Nodes. // Its code-structure has been vastly inspired from // torch/csrc/jit/passes/create_autodiff_subgraphs.cpp void CreateLlgaSubgraphs(std::shared_ptr& graph); } // namespace torch::jit::fuser::onednn