#pragma once #include #include #include #include #include #include #include #include namespace torch::jit::tensorexpr { using VarMapping = std::vector>; class VarSubMutator : public IRMutator { public: VarSubMutator(const VarMapping& var_mapping) { for (auto& entry : var_mapping) { VarPtr key_var = entry.first; ExprPtr value = entry.second; if (!key_var) { throw malformed_input("missing key in VarSubMutator"); } var_mapping_[std::move(key_var)] = std::move(value); } } ExprPtr mutate(const VarPtr& var) override { auto iter = var_mapping_.find(var); if (iter == var_mapping_.end()) { return var; } return iter->second; } ExprPtr mutate(const ReduceOpPtr& var) override { auto body = var->body()->accept_mutator(this); std::vector new_inner; for (const auto& v : var->reduce_args()) { ExprPtr e = v->accept_mutator(this); if (VarPtr new_var = to(e)) { new_inner.push_back(std::move(new_var)); } else { VarFinder varFinder; e->accept(&varFinder); auto varlist = varFinder.vars(); new_inner.insert(new_inner.end(), varlist.begin(), varlist.end()); } } return alloc(body, new_inner, var->reducer()); } private: std::unordered_map var_mapping_; }; } // namespace torch::jit::tensorexpr