#pragma once #include #include #include #include #include #include #include #include #include #include #include #include // Uses a compressed index representation for faster comparisons typedef c10::SparseBitVector<256> MemoryLocations; namespace torch::jit { struct Value; using AliasTypeSet = std::vector; // `Element` represents a vertex in the points-to graph. It represents // anything that could have an aliasing relationship--mostly IR // `Value`s, but also wildcards or the type inside a container (e.g. `T` // in `List[T]`) struct Element { Element(const Value* value_, unsigned index_); // wildcard constructor explicit Element(unsigned index_); // Index into the owning DAG's bit vector that represents this element. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) unsigned index; // All elements that this element *may* point to. It's possible to have // multiple elements that you might point to due to control flow/complex ops // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) MemoryLocations pointsTo; // Backreference for points-to. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) MemoryLocations pointedFrom; // Elements can contain other elements (e.g. List[Tensor]) // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) MemoryLocations containedElements; // The values that this element corresponds to. May be empty if this element // doesn't represent a first-class value. // This is for debug information only. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::unordered_set values; private: // Make `from` point at `to`. void makePointerTo(Element* from, Element* to); friend class MemoryDAG; // We memoize the results of `getMemoryLocations` to speed up queries. // A nullopt means that this cache is not yet populated. Since `MemoryDAG` is // immutable, this cache should never need to be invalidated. mutable std::optional cachedMemoryLocations_; mutable std::optional cachedAllContainedMemoryLocations_; }; // class MemoryDAG // // This class tracks the "A points to B" graph for all values. It is used by // AliasDb to provide a higher-level API. // // We maintain a DAG where: // - Vertices (called "Elements") represent Values and // other aliasing entities (e.g. the stuff inside a list) // - Edges represent a "points-to" relationship. // // Leaves in this DAG are entities that don't point to anything, and thus // correspond to unique "memory locations". // // So, by traversing the "points-to" graph to the leaves, you can determine // which memory locations an element may point to. class TORCH_API MemoryDAG { public: explicit MemoryDAG(std::vector> indexToElementMap) : indexToElementMap_(std::move(indexToElementMap)) {} // explicitly delete copy constructor because otherwise windows build is // confused for an exported class see // https://stackoverflow.com/a/51033485/105137 MemoryDAG(const MemoryDAG&) = delete; MemoryDAG& operator=(const MemoryDAG&) = delete; // Return the unique memory locations that `Element` might represent. const MemoryLocations& getMemoryLocations(const Element* e) const; // Do `a` and `b` potentially share a memory location? bool mayAlias(const Element* a, const Element* b) const; // Does `a` hold reference to any memory that is stored in `b`, or vice versa? bool mayContainAlias(const Element* a, const Element* b) const; bool mayContainAlias(const Element* a, const at::ArrayRef b) const; bool mayContainAlias( const at::ArrayRef a, const at::ArrayRef b) const; // Converts from the compressed index representation const Element* fromIndex(unsigned x) const; Element* fromIndex(unsigned x); void collectAllContainedMemoryLocations( const Element* elem, MemoryLocations& cont) const; /** * The following methods are special cases where we need to mutate the * internals of MemoryDAG for efficiency reasons. Don't call them unless you * know what you're doing! In particular, don't add new mutating methods * without ensuring that you are maintaining cache consistency for memory * locations. */ // Adding wildcards can trigger extremely expensive cache invalidations. This // method adds them in a more efficient cache-aware way. void setWildcards( const std::unordered_set& wildcards, const ska::flat_hash_map& elementMap, const std::function& getWildcardElement); Element* unsafeMakeFreshValue(const Value* v); private: const MemoryLocations& getAllContainedMemoryLocations( const Element* elem) const; void collectAllContainedMemoryLocationsImpl( const Element* elem, MemoryLocations& cont) const; std::vector> indexToElementMap_; }; /** * Helper to build up the points-to graph. * * We separate the "building" into a different class because it allows us to * cache internally to MemoryDAG without worrying about how the DAG structure * is mutated. */ class TORCH_API MemoryDAGBuilder { public: MemoryDAGBuilder() = default; MemoryDAGBuilder(const MemoryDAGBuilder&) = delete; MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete; // Make `from` point at `to`. void makePointerTo(Element* from, Element* to); void addToContainedElements(Element* contained, Element* container); std::unique_ptr createMemoryDAG() && { return std::make_unique(std::move(indexToElementMap_)); } // Make a fresh Element (i.e. an Element that doesn't point to anything) and // return it. Element* makeFreshValue(const Value* v); friend MemoryDAG; private: // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses // the map to construct the `MemoryDAG` std::vector> indexToElementMap_; }; } // namespace torch::jit