#pragma once #include #include #include #include #include namespace torch::jit { struct TORCH_API MutationRemover { MutationRemover( std::shared_ptr graph, std::optional> mutation_filter = std::nullopt) : mutation_filter_(std::move(mutation_filter)), aliasDb_(nullptr), graph_(std::move(graph)) {} // return true if graph is modified bool removeListMutation(); // return true if graph is modified bool removeTensorMutation(); bool isSpecialMappedOp(Node* n) { return n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)") || n->matches( "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)") || n->matches( "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"); } bool inplaceOpVariant(Node* n); static bool hasSideEffectOrAlias(Value* v, AliasDb* aliasDb); private: Node* createSpecialMappedOp(Node* n); bool listMutationFollowingListConstruct(Node* n); bool tryMakeCreationAndMutationAtomic( Value* mutated_value, Node* mutating_op); bool tryMakeUnaliasedIfOutputAndMutationAtomic( Value* mutated_value, Node* mutating_op); // return true if graph is modified bool RemoveListMutation(Block* block); // return true if graph is modified bool RemoveTensorMutation(Block* block); AliasDb* getOrCreateAliasDb() { if (!aliasDb_) { aliasDb_ = std::make_unique(graph_); } return aliasDb_.get(); } std::optional> mutation_filter_; std::unique_ptr aliasDb_ = nullptr; std::shared_ptr graph_; }; // Removes list mutation with functional equivalents // return true if graph is modified TORCH_API bool RemoveListMutation(const std::shared_ptr& graph); // Replaces in-place aten ops with their functional equivalents // when it can be proven that this does not change graph semantics // if `mutation_filter` is present, the pass will only attempt to // remove mutation on nodes which return true for the filter // return true if graph is modified TORCH_API bool RemoveTensorMutation( const std::shared_ptr& graph, std::optional> mutation_filter = std::nullopt); // Replaces in-place aten activation ops with their functional equivalence TORCH_API bool InplaceToFunctionalActivation( const std::shared_ptr& graph); } // namespace torch::jit