#pragma once #include #include namespace at { using GeneratorFuncType = std::function; TORCH_API std::optional& GetGeneratorPrivate(); class TORCH_API _GeneratorRegister { public: explicit _GeneratorRegister(const GeneratorFuncType& func); }; TORCH_API at::Generator GetGeneratorForPrivateuse1( c10::DeviceIndex device_index); /** * This is used to register Generator to PyTorch for `privateuse1` key. * * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1) * * class CustomGeneratorImpl : public c10::GeneratorImpl { * CustomGeneratorImpl(DeviceIndex device_index = -1); * explicit ~CustomGeneratorImpl() override = default; * ... * }; * * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) { * return at::make_generator(id); * } */ #define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \ static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate); } // namespace at