#pragma once #include #include #include #include #include #include #include #include #include namespace torch::jit::tensorexpr::analysis { enum class AccessType { Input, Output, Load, Store, Call, AtomicAdd, Alloc, Free }; const char* AccessToString(AccessType a); class AccessInfo; using DependencySet = std::unordered_set>; /* AccessInfo * * Represents a single bounded memory access to a buffer, for instance a Load or * a Store. Holds information relating to the specific access and links to * connected accesses in the dependency graph. */ class TORCH_API AccessInfo { public: AccessInfo( size_t id, AccessType type, StmtPtr stmt, VarPtr var, IndexBounds bounds) : id_(id), type_(type), stmt_(std::move(stmt)), expr_(nullptr), var_(std::move(var)), bounds_(std::move(bounds)) {} AccessInfo( size_t id, AccessType type, ExprPtr expr, StmtPtr stmt, VarPtr var, IndexBounds bounds) : id_(id), type_(type), stmt_(std::move(stmt)), expr_(std::move(expr)), var_(std::move(var)), bounds_(std::move(bounds)) {} // Id is a unique int representing the order this access occurred in the // graph. size_t id() const { return id_; } // The type of the access (Load, Store, etc). AccessType type() const { return type_; } // The enclosing Stmt this access represents. E.g. if this is a Store then // Stmt is the Store itself, while if the access is caused by an Expr, this is // the most immediate parent Stmt. StmtPtr stmt() const { return stmt_; } // If the access is represented by an Expr (such as Load or Call) then this is // it, otherwise it's nullptr. ExprPtr expr() const { return expr_; } // The Var representing the underlying Buffer. VarPtr var() const { return var_; } // A vector of Bounds representing the start and end expression for each // dimension. IndexBounds& bounds() { return bounds_; } // Each access that this depends upon, // eg. if this is a Load, then it contains every Store that immediately // contributes to a load of the bounds. // or: if this is a Store, it contains all reads on the RHS of the Store. const std::map>& dependencies() const { return dependencies_; } // Each access that depends on this one. // ie. this access is present in the dependencies map of all accesses that are // dependent. std::map> dependents() const { std::map> res; for (const auto& kv : dependents_) { res.emplace(kv.first, kv.second.lock()); } return res; } // Returns the symbolic expression of the indices of this access. std::vector getIndices() const; // Establishes a dependency or dependent relationship with another access. void addDependency(const std::shared_ptr& write); void addDependent(const std::shared_ptr& read); // helper for checking dependencies. bool hasDependency(const std::shared_ptr& info) const; // Returns the set of all nodes that are direct (immediate) dependencies of // this access. DependencySet getDirectDependencies(); // likewise, returns all nodes that directly depend on this one. DependencySet getDirectDependents(); // Returns the full list of all nodes in the graph that this access depends // on, and all nodes they depend on, and so forth, back to the inputs. DependencySet getIndirectDependencies(); // likewise, returns the full list of all nodes that depend on this node, and // all nodes that depend on those nodes and so on down to the outputs. DependencySet getIndirectDependents(); // Does this access represent a read of memory (Load, ReduceOp, Call, etc). bool isRead() const; // Does this access represent a write of memory (Store, etc). bool isWrite() const; // Helpers for dumping accesses in various formats. void print() const; void dumpDOT(std::ostream& os) const; const char* AccessTypeColour() const; private: size_t id_; AccessType type_; StmtPtr stmt_; ExprPtr expr_; VarPtr var_; IndexBounds bounds_; // Yes these should be sorted. std::map> dependencies_; std::map> dependents_; }; using VarBoundMap = std::unordered_map; /* MemDependencyChecker analyses a IR fragment and builds a dependency graph of * accesses contained within. * * It's possible to retrieve the entire graph in node-object form, or can be * used as an oracle for answering dependency questions. e.g: * * analyzer.hasIndirectDependency(BufA, BufB); or, * analyzer.hasDirectDependency(LoadA, StoreB); */ class TORCH_API MemDependencyChecker : public IRVisitor { struct Scope; public: MemDependencyChecker(); MemDependencyChecker( const std::unordered_set& inputs, const std::unordered_set& outputs); MemDependencyChecker( const std::vector& inputs, const std::vector& outputs); ~MemDependencyChecker() override = default; // Whether or not to allow loop execution order to influence dependency // calculation. If the loop may later be parallelized you don't want this. bool allowLoopExecutionOrderAnalysis(bool allow = true); // Dependency Checking API. // The goal is to have enough overloads here so you don't really have to think // about it. // Returns true if any read in A has a direct dependence on a write in B. bool dependsDirectly(const StmtPtr& A, const StmtPtr& B); bool dependsDirectly(const ExprPtr& A, const StmtPtr& B); // Returns true of the output depends directly on a write contained in B. bool dependsDirectly(const BufPtr& output, const StmtPtr& B); // Returns true if a read in A depends directly on the provided input. bool dependsDirectly(const StmtPtr& A, const BufPtr& input); bool dependsDirectly(const ExprPtr& A, const BufPtr& input); // Outputs/inputs cannot depend directly. // Returns true if the access A has B as an immediate dependency. bool dependsDirectly( const std::shared_ptr& A, const std::shared_ptr& B); // Returns true if any read in A has an ancestor write contained in B. bool dependsIndirectly(const StmtPtr& A, const StmtPtr& B); bool dependsIndirectly(const ExprPtr& A, const StmtPtr& B); // Returns true of the output depends indirectly on a write contained in B. bool dependsIndirectly(const BufPtr& output, const StmtPtr& B); // Returns true if a read in A depends indirectly on the provided input. bool dependsIndirectly(const StmtPtr& A, const BufPtr& input); bool dependsIndirectly(const ExprPtr& A, const BufPtr& input); // returns true if the output uses any load of the input. bool dependsIndirectly(const BufPtr& output, const BufPtr& input); // Returns true if the access A has a dependency chain to access B. bool dependsIndirectly( const std::shared_ptr& A, const std::shared_ptr& B); // Returns the AccessInfo std::shared_ptr accessFor(const StmtPtr& A) const; std::shared_ptr accessFor(const ExprPtr& A) const; // Returns all AccessInfos. std::unordered_set> accessesWithin( const StmtPtr& A) const; // TODO: this will return only the AccessInfo for A. It's included for // completeness but be aware it wont return accesses used in the computation // of A. std::unordered_set> accessesWithin( const ExprPtr& A) const; // Accesses relating to input and output buffers. std::shared_ptr input(const BufPtr& B) const; std::shared_ptr output(const BufPtr& B) const; // Returns the full history of reads and writes. const std::vector>& getHistory() const; // Dumps the dependency graph in DOT format. void dumpDAG(const std::string& filename) const; private: // Node visitors. void visit(const StorePtr& v) override; void visit(const LoadPtr& v) override; void visit(const ForPtr& v) override; void visit(const CondPtr& v) override; void visit(const IfThenElsePtr& v) override; void visit(const CompareSelectPtr& v) override; void visit(const BlockPtr& v) override; void visit(const LetPtr& v) override; void visit(const AtomicAddPtr& v) override; void visit(const AllocatePtr& v) override; void visit(const FreePtr& v) override; using BoundRelationship = std::pair>; // An internal struct holding the accesses found within a scope Block. struct Scope { Scope(BlockPtr b, std::shared_ptr p) : block(std::move(b)), parent(std::move(p)) {} BlockPtr block; std::shared_ptr parent; std::unordered_map shadowedVarBounds; std::unordered_set localVars; std::vector> accesses_; std::unordered_map> openWrites_; }; std::shared_ptr currentScope_; bool allowExecutionOrderAnalysis_{false}; std::unordered_multimap> stmtToAccess_; std::unordered_multimap> exprToAccess_; std::unordered_map>> scopeToAccesses_; VarBoundMap knownVarBounds_; // Finds all accesses that are reads within the scope of v. template DependencySet getAllReadsWithin(const StmtOrExprPtr& v) { DependencySet reads; auto insertAllReads = [&](const auto& nodes) { for (const auto& l : nodes) { auto bound = exprToAccess_.equal_range(l); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isRead()) { reads.insert(it->second); } } } }; // Look for and insert accesses belonging to all nodes that act like // reads. insertAllReads(NodeFinder::find(v)); insertAllReads(NodeFinder::find(v)); return reads; } // Finds all accesses that are writes within the scope of v. // Writes cannot occur in Exprs, so this is a little simpler. DependencySet getAllWritesWithin(const StmtPtr& v) { DependencySet writes; // writes just Store currently. auto stores = NodeFinder::find(v); for (const auto& s : stores) { auto bound = stmtToAccess_.equal_range(s); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isWrite()) { writes.insert(it->second); } } } return writes; } // Templated helpers to work on either Exprs or Stmts. template bool dependsDirectlyHelper(const StmtOrExprPtr& A, const StmtPtr& B) { auto aReads = getAllReadsWithin(A); auto bWrites = getAllWritesWithin(B); for (auto& read : aReads) { for (auto& depPair : read->dependencies()) { if (bWrites.count(depPair.second) != 0) { return true; } } } return false; } template bool dependsIndirectlyHelper(StmtOrExprPtr A, const StmtPtr& B) { auto aReads = getAllReadsWithin(A); auto bWrites = getAllWritesWithin(B); auto aDeps = getAllWriteDependencies(aReads); for (auto& dependency : aDeps) { if (bWrites.count(dependency) != 0) { return true; } } return false; } DependencySet getAllWriteDependencies(const DependencySet& products); // Maps for inputs and outputs, since they aren't present directly in the IR. std::unordered_map> inputs_; std::unordered_map> outputs_; std::unordered_map> intermediates_; // Inserts accesses for Buf's: specifically for inputs and outputs. void insertBuffers( std::unordered_map>& bufs, AccessType type); // Update the write history with a new write, adding dependencies and closing // any overlapped writes (if possible). void updateWriteHistory( std::list& writeHistory, const std::shared_ptr& info, size_t latestAccessToClose, bool closeOverlapped = true, bool insert = true); // Merge a child scope into a parent scope, adding dependencies for open // writes in the parent to accesses in the child. void mergeScope( const std::shared_ptr& child, const std::shared_ptr& parent, bool closeOverlapped = true); // Binds symbolic vars in indices with the low and high bound for those vars. std::vector getIndicesBounds(const std::vector& indices); size_t nextAccess_{0}; StmtPtr lastStmt_{nullptr}; }; } // namespace torch::jit::tensorexpr::analysis