#pragma once

// NOTE(review): in this copy the target of every #include below, and the
// argument list of every template (e.g. `std::optional>`,
// `std::unordered_map&`, `std::function skipSizeCheck`), appears to have been
// stripped, so this header cannot compile as shown. Restore the missing text
// from version control rather than guessing it.
#include
#include
#include
#include
#include
#include

namespace torch::nativert {

// Version tag carried by each Weights instance (see Weights::version()).
using WeightVersion = int;

/**
 * @brief A class that manages the weights of a graph, providing functionality
 * to load, access, and manipulate them.
 *
 * It is responsible for handling the parameters, buffers, and constants
 * associated with a graph. It provides mechanisms to load weights from
 * serialized data, access and modify them, and performs necessary validation
 * checks.
 *
 * NOTE(review): the class has a reference data member (weightsMeta_) and a
 * const data member (version_), so its implicitly-declared copy and move
 * assignment operators are deleted.
 */
class Weights {
 public:
  // Builds weights for `graph`, optionally populated from an in-memory
  // `stateDict` (std::nullopt by default, i.e. none supplied).
  // - placement: device placement of the weights; defaults to Placement().
  explicit Weights(
      const Graph* graph,
      const std::optional>& stateDict = std::nullopt,
      Placement placement = Placement());

  // Builds weights for `graph` by reading them from a model archive.
  //
  // Arguments
  // - graph: the graph whose weights are managed
  // - pytorchStreamReader: the reader for the model archive
  // - stateDictPaths: a map from parameter/buffer/constant name to file path
  //   in the archive
  // - stateDictPathPrefix: a prefix that will be prepended to the paths in
  //   stateDictPaths
  // - constantPaths: a map from constant name to file path in the archive
  // - constantPathPrefix: a prefix that will be prepended to the paths in
  //   constantPaths
  // - placement: the device placement of the weights; defaults to following
  //   the original device in the weight's metadata
  // - skipSizeCheck / skipDtypeCheck: optional callbacks, empty by default.
  //   NOTE(review): presumably predicates that exempt individual weights from
  //   the size / dtype validation; their signature was stripped from this
  //   copy — confirm against the implementation.
  explicit Weights(
      const Graph* graph,
      std::shared_ptr pytorchStreamReader,
      const std::unordered_map& stateDictPaths,
      std::string_view stateDictPathPrefix,
      const std::unordered_map& constantPaths,
      std::string_view constantPathPrefix,
      Placement placement = Placement(),
      std::function skipSizeCheck = {},
      std::function skipDtypeCheck = {});

  // Tensor lookup by weight name. The const overload returns by value; the
  // non-const overload returns a mutable reference (presumably into
  // allValues_ — confirm in the .cpp).
  at::Tensor at(const std::string& name) const;
  at::Tensor& at(const std::string& name);

  // NOTE(review): presumably true iff a weight called `name` is held.
  bool contains(const std::string& name) const;

  // Custom-object lookup, either by name or by the archive file name the
  // object was loaded from (presumably resolved through customObjsPaths_).
  c10::IValue getCustomObj(const std::string& name) const;
  c10::IValue getCustomObjByFileName(const std::string& name) const;

  // Name-keyed collections of the held weights, returned by value.
  std::unordered_map parameters() const;
  std::unordered_map buffers() const;
  std::unordered_map attributes() const;

  // Loads weights from an in-memory state dict.
  void loadStateDict(
      const std::unordered_map& stateDict);

  /*
   * Replace the value stored at the weight with name "name".
   */
  void setValue(const std::string& name, const at::Tensor& newValue);

  /*
   * Update the value stored at the weight with name "name".
   * This is done in-place.
   */
  void updateValue(const std::string& name, const at::Tensor& newValue);

  // Batch form of updateValue(); presumably applied once per map entry.
  void updateValues(
      const std::unordered_map& newValues);

  // Checks `newValue` against the weight called `name`; const, so it does not
  // modify the stored weights. NOTE(review): the exact checks and the failure
  // mode are not visible in this header.
  void validateValue(const std::string& name, const at::Tensor& newValue) const;

  // NOTE(review): presumably verifies that every expected weight has been
  // loaded; confirm the failure mode in the .cpp.
  void validateAllWeightsLoaded();

  // Folded constants derived from these weights (presumably stored in
  // foldedConsts_ / foldedConstsMap_ — confirm).
  void updateFoldedConst(std::string_view name, c10::IValue tensor);
  const std::unordered_map& getFoldedConsts() const;

  // Read-only access to constFoldedValues_, returned by reference (no copy).
  C10_ALWAYS_INLINE const c10::FastMap& getConstFoldedValues() const {
    return constFoldedValues_;
  }

  // Stores `iv` under key `n`, overwriting any existing entry
  // (insert_or_assign); `iv` is moved into the map.
  // NOTE(review): this map is separate from the foldedConsts_ /
  // foldedConstsMap_ members; how the two relate is not visible here.
  C10_ALWAYS_INLINE void setConstFoldedValue(
      const std::string& n,
      c10::IValue iv) {
    constFoldedValues_.insert_or_assign(n, std::move(iv));
  }

  std::string toString() const;

  // Version number of this instance (returns version_).
  WeightVersion version() const {
    return version_;
  }

 private:
  // NOTE(review): raw pointer; ownership and lifetime of the Graph are not
  // expressed here — presumably the caller keeps it alive. Confirm.
  const Graph* graph_;
  // Reference member: the referenced map must stay alive for as long as this
  // object uses it. NOTE(review): presumably bound to metadata owned by
  // graph_ — confirm.
  const std::unordered_map& weightsMeta_;
  Placement placement_;

  // keys are parameter/buffer/constant names, not graph input names!
  std::unordered_map allValues_;
  std::unordered_map customObjs_;
  // Maps a file name to an arbitrary key in customObjs_ whose entry
  // (a CustomClassHolder) holds the loaded content of that file.
  // This is used in AOTIDelegateExecutor.
  std::unordered_map customObjsPaths_;

  // The lifecycle of folded consts should be tied to the weights from which
  // they were derived. The ordering of the constants should be consistent
  // with the output order of the const graph.
  std::vector foldedConsts_;
  std::unordered_map foldedConstsMap_;
  c10::FastMap constFoldedValues_;

  // Unique version number for this instance of Weights.
  // NOTE(review): no user-declared copy/move constructor is visible, so an
  // implicitly copy- or move-constructed Weights would carry the same
  // version_ as its source — confirm that is intended.
  const WeightVersion version_;
  // Every instance of Weights has a unique version number; presumably this
  // counter supplies it. NOTE(review): plain int, not atomic — if it is
  // incremented in constructors that can run concurrently, uniqueness is not
  // guaranteed; check the .cpp.
  static WeightVersion globalVersion_;

  std::function skipSizeCheck_ = {};
  std::function skipDtypeCheck_ = {};

  // save the names of unused weights
  std::unordered_set unusedWeights_;
};

} // namespace torch::nativert