/** * Hash utils in this file is adapted from PyTorch/XLA * https://github.com/pytorch/xla/blob/e0e5f937a0ba8d904f9608137dc8c51ba439df2d/third_party/xla_client/util.h */ #pragma once #include #include #include #include #include #include #include #include #include namespace torch::lazy { using size_t = std::size_t; class TORCH_API hash_t : public c10::uint128 { public: // Switch from typedef hash_t = uint128 to provide explicit casters hash_t(int8_t val) : uint128(static_cast(val)) {} hash_t(int16_t val) : uint128(static_cast(val)) {} hash_t(int32_t val) : uint128(static_cast(val)) {} hash_t(int64_t val) : uint128(static_cast(val)) {} hash_t(uint32_t val) : uint128(val) {} hash_t(uint64_t val) : uint128(val) {} hash_t(uint128 val) : uint128(val) {} hash_t(uint64_t top, uint64_t bottom) : uint128(top, bottom) {} hash_t() = default; }; // Std* functions use 64-bit hash size_t TORCH_API StdDataHash(const void* data, size_t size); size_t TORCH_API StdHashCombine(uintmax_t a, uintmax_t b); // Other functions are all 128-bit hash_t TORCH_API HashBlock(const void* data, size_t n, const hash_t& seed); hash_t TORCH_API DataHash(const void* data, size_t size); hash_t TORCH_API HashCombine(const hash_t& a, const hash_t& b); size_t TORCH_API HashReduce(const hash_t& a); // Returns a string representation of a hash std::string TORCH_API HashToString(const hash_t& a); struct HashReducer { size_t operator()(const hash_t& value) const { return HashReduce(value); } }; static inline hash_t StringHash(const char* data) { return DataHash(data, std::strlen(data)); } // Automatic templated implementation for 'arithmetic' types template >* = nullptr> hash_t Hash(const T& value) { return DataHash(&value, sizeof(value)); } // added because on macos builds the vector specialization // breaks falling through to the templated arithmetic types above hash_t TORCH_API Hash(const std::vector& value); // Specialized implementations for proprietary types static inline hash_t Hash(const c10::ScalarType& value) { return DataHash(&value, sizeof(value)); } static inline hash_t Hash(const c10::MemoryFormat& value) { return DataHash(&value, sizeof(value)); } static inline hash_t Hash(const c10::DeviceType& value) { return DataHash(&value, sizeof(value)); } static inline hash_t Hash(const c10::Device& value) { return HashCombine(Hash(value.type()), Hash(value.index())); } static inline hash_t Hash(const c10::Layout& value) { return DataHash(&value, sizeof(value)); } static inline hash_t Hash(const c10::Scalar& value) { switch (value.type()) { case c10::ScalarType::ComplexDouble: return Hash(value.toComplexDouble()); case c10::ScalarType::Double: return Hash(value.toDouble()); case c10::ScalarType::Long: return Hash(value.toLong()); case c10::ScalarType::Bool: return Hash(value.toBool()); default: TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type()); } } static inline hash_t TensorHash(const at::Tensor& tensor) { at::Tensor ctensor = tensor.contiguous(); int64_t size = ctensor.numel() * ctensor.element_size(); switch (ctensor.scalar_type()) { case at::ScalarType::Bool: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Byte: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Char: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Short: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Int: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Long: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Float: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Double: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::BFloat16: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::Half: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::ComplexFloat: return DataHash(ctensor.const_data_ptr>(), size); case at::ScalarType::ComplexDouble: return DataHash(ctensor.const_data_ptr>(), size); case at::ScalarType::UInt16: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::UInt32: return DataHash(ctensor.const_data_ptr(), size); case at::ScalarType::UInt64: return DataHash(ctensor.const_data_ptr(), size); default: TORCH_INTERNAL_ASSERT( false, "Unsupported scalar type:", ctensor.scalar_type()); } } static inline hash_t Hash(const std::string& value) { return DataHash(value.data(), value.size()); } static inline hash_t Hash(const std::string_view& value) { return DataHash(value.data(), value.size()); } static inline hash_t Hash(const at::Generator& value) { return TensorHash(value.get_state()); } // Taken from glibc's implementation of hashing optionals, // we want to include a contribution to the hash to distinguish // cases where one or another option was null, but we hope it doesn't // collide with an actually scalar value. // // Use an arbitrary randomly-selected 64-bit integer rather than a // small constant that we then hash at runtime so we don't have to // repeatedly hash a constant at runtime. // NOLINTNEXTLINE(*-narrowing-conversions) static const int64_t kNullOpt = 0x8655d738f3678dda; // Hashing for std::optional types contributes to hash // for optionals with null value, important to distinguish // between and cases template hash_t Hash(const std::optional& value) { if (value.has_value()) { return Hash(value.value()); } else { return kNullOpt; } } // Hashing of containers // Forward declare to allow hashes of vectors of vectors to work. template hash_t ContainerHash(const T& values); template hash_t Hash(const std::vector& values) { return ContainerHash(values); } // Need a special case for std::optional? template hash_t Hash(const std::optional>& value) { if (value.has_value()) { return ContainerHash(value.value()); } else { return kNullOpt; } } template hash_t Hash(const std::set& values) { return ContainerHash(values); } template hash_t Hash(const std::pair& values) { return HashCombine(Hash(values.first), Hash(values.second)); } static inline hash_t Hash(const hash_t& value) { return value; } template hash_t Hash(c10::ArrayRef values) { return ContainerHash(values); } template hash_t ContainerHash(const T& values) { hash_t h(static_cast(0x85ebca77c2b2ae63)); for (const auto& value : values) { h = HashCombine(h, Hash(value)); } return h; } // Varargs hashing template hash_t MHash() { return hash_t(static_cast(0x165667b19e3779f9)); } template hash_t MHash(T value, Targs... Fargs) { return HashCombine(Hash(value), MHash(Fargs...)); } } // namespace torch::lazy