// This whole header is compiled only when neither C10_MOBILE nor ANDROID is
// defined.
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

// NOTE(review): the three #include directives below carry no header name in
// this copy of the file -- it looks like the <...> targets were lost. Restore
// them before building.
#include
#include
#include

// Forward declare DynamicLibrary
namespace at {
struct DynamicLibrary;
}

namespace torch::inductor {
// NOTE(review): std::unordered_map, std::vector and std::unique_ptr appear
// throughout this header without template arguments -- presumably lost the
// same way as the include targets. Confirm against the upstream file.
using TensorConstantMap = std::unordered_map;

// Runner interface around an AOTInductor model container.
//
// The class is neither default-constructible, copyable nor movable; the only
// usable constructor is protected, so instances are created through derived
// classes. All member functions are only declared here; their definitions
// live elsewhere.
class TORCH_API AOTIModelContainerRunner {
 public:
  // No default construction, copying or moving.
  AOTIModelContainerRunner() = delete;
  AOTIModelContainerRunner(const AOTIModelContainerRunner& other) = delete;
  AOTIModelContainerRunner(AOTIModelContainerRunner&& other) = delete;
  AOTIModelContainerRunner& operator=(const AOTIModelContainerRunner& other) =
      delete;
  AOTIModelContainerRunner& operator=(AOTIModelContainerRunner&& other) =
      delete;
  virtual ~AOTIModelContainerRunner();

  // Takes the inputs by const reference; `stream_handle` defaults to nullptr.
  std::vector run(
      const std::vector& inputs,
      void* stream_handle = nullptr);

  // boxed_run will steal the ownership of the input tensors
  std::vector boxed_run(
      std::vector&& inputs,
      void* stream_handle = nullptr);

  // Const accessors returning maps by value.
  // presumably keyed by constant name, as the function names suggest --
  // TODO confirm once the template arguments are restored.
  std::unordered_map getConstantNamesToOriginalFQNs() const;
  std::unordered_map getConstantNamesToDtypes() const;
  const std::unordered_map extract_constants_map(
      bool use_inactive) const;

  // Constant-buffer management entry points.
  void update_inactive_constant_buffer(const TensorConstantMap& const_map);
  // Two overloads: one takes a mutable map reference, the other a const
  // TensorConstantMap reference. `user_managed` defaults to false in both.
  void update_constant_buffer(
      std::unordered_map& tensor_map,
      bool use_inactive,
      bool validate_full_updates,
      bool user_managed = false);
  void update_constant_buffer(
      const TensorConstantMap& const_map,
      bool use_inactive,
      bool validate_full_updates,
      bool user_managed = false);
  // `cuda_stream_handle` defaults to nullptr.
  void run_const_fold(
      bool use_inactive,
      AOTInductorStreamHandle cuda_stream_handle = nullptr);
  void swap_constant_buffer();
  void free_inactive_constant_buffer();
  std::vector get_call_spec();

 protected:
  // The only non-deleted constructor; protected, so it is reachable from
  // derived classes only.
  AOTIModelContainerRunner(
      const std::string& model_so_path,
      size_t num_models,
      const std::string& device_str,
      const std::string& cubin_dir,
      const bool run_single_threaded);

  // Virtual hook that derived classes may override.
  // NOTE(review): presumably what run()/boxed_run() dispatch to -- confirm in
  // the implementation file.
  virtual std::vector run_impl(
      std::vector& input_handles,
      void* stream_handle);

  // NOTE(review): presumably the loaded model shared library (see the
  // at::DynamicLibrary forward declaration above) -- confirm once the
  // template argument is restored.
  std::unique_ptr model_so_;

  // Function-pointer slots, one per AOTInductorModelContainer* entry point.
  // Each is typed via decltype of the corresponding symbol and starts out as
  // nullptr.
  decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr};
  decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr};
  decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{
      nullptr};
  decltype(&AOTInductorModelContainerRun) run_func_{nullptr};
  decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{
      nullptr};
  decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{
      nullptr};
  decltype(&AOTInductorModelContainerGetConstantOriginalFQN)
      get_constant_original_fqn_func_{nullptr};
  decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{
      nullptr};
  decltype(&AOTInductorModelContainerExtractConstantsMap)
      extract_constants_map_func_{nullptr};
  decltype(&AOTInductorModelContainerUpdateUserManagedConstantBuffer)
      update_user_managed_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerUpdateConstantBuffer)
      update_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer)
      update_inactive_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{
      nullptr};
  decltype(&AOTInductorModelContainerSwapConstantBuffer)
      swap_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerFreeInactiveConstantBuffer)
      free_inactive_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr};

  AOTInductorModelContainerHandle container_handle_ = nullptr;

  // NOTE(review): unlike container_handle_, this member has no default
  // initializer here -- presumably the constructor always assigns it; TODO
  // confirm.
  AOTIProxyExecutorHandle proxy_executor_handle_;

 private:
  std::unique_ptr proxy_executor_;
};

// Signature of a factory function for a runner. Its parameter list mirrors
// the protected AOTIModelContainerRunner constructor, except that the fourth
// parameter is named `bin_dir` here and `cubin_dir` there.
using CreateAOTIModelRunnerFunc = std::unique_ptr (*)(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& bin_dir,
    const bool run_single_threaded);

// Returns the global map "device name" -> "aoti model runner create function"
// for all external backends registered in AOTI.
TORCH_API std::unordered_map& getAOTIModelRunnerRegistry();

// To register a new external backend in AOTI one needs to create an instance of
// this struct.
It is not thread-safe. Because it is expected to be called // during the initialization of the program. struct TORCH_API RegisterAOTIModelRunner{RegisterAOTIModelRunner( const std::string& name, CreateAOTIModelRunnerFunc create_aoti_model_runner_fn){ getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn; } // namespace torch::inductor } ; } // namespace torch::inductor #endif