#pragma once #include #include #include #include namespace torch::nativert { /** * This function returns a normalized version of the input device: * - For CPU devices, the returned device will have no index (i.e., the default * CPU device). * - For CUDA devices, if no index is specified, index 0 is assumed. * - For other device types, the function will raise an error. * * @param device The input c10::Device to normalize. * @return A normalized c10::Device with standardized indexing. * * @throws c10::Error If the device type is not CPU or CUDA. */ c10::Device normalizeDevice(const c10::Device& device); /** * Returns true if the two devices are the same and has the same device index * (if cuda). */ bool isSameDevice(const c10::Device& device1, const c10::Device& device2); /** * @brief A utility class for managing device placement mappings. * * The Placement class provides a way to map source devices to target devices. * It supports both explicit per-device mappings and a default device fallback. * This is the argument taken in NativeRT to map from model artifact device to * the device it should run on. */ struct TORCH_API Placement { Placement() = default; explicit Placement(std::optional defaultDevice); explicit Placement( const std::unordered_map& deviceMap, std::optional defaultDevice = std::nullopt); c10::Device getMappedDevice(const c10::Device& srcDevice) const; TORCH_API friend std::ostream& operator<<( std::ostream& os, const Placement& obj); protected: std::unordered_map deviceMap_; std::optional defaultDevice_; }; } // namespace torch::nativert