#pragma once #include #include #include #include #include #include #include #include namespace at::native::onednn { TORCH_XPU_API dnnl::memory make_onednn_memory( dnnl::memory::desc md, dnnl::engine& engine, void* ptr); // Keep non-static and non-inline bool set_onednn_verbose(int level); // GpuEngineManager singleton struct TORCH_XPU_API GpuEngineManager { static GpuEngineManager& Instance(); // Singleton dnnl::engine& get_engine(const Device& device) { TORCH_INTERNAL_ASSERT(device.type() == kXPU); TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count()); return *engine_pool[device.index()]; } GpuEngineManager(GpuEngineManager const&) = delete; GpuEngineManager& operator=(GpuEngineManager const&) = delete; GpuEngineManager(GpuEngineManager&&) = default; GpuEngineManager& operator=(GpuEngineManager&&) = default; protected: GpuEngineManager(); ~GpuEngineManager() = default; private: std::vector> engine_pool; }; // GpuStreamManager singleton struct TORCH_XPU_API GpuStreamManager { static GpuStreamManager& Instance(); // Singleton dnnl::stream get_stream() { auto stream = c10::xpu::getCurrentXPUStream(); auto priority = stream.priority(); auto device_index = stream.device_index(); if (stream_pool[device_index][priority].find(stream) == stream_pool[device_index][priority].end()) { stream_pool[device_index][priority][stream] = std::make_shared(dnnl::sycl_interop::make_stream( GpuEngineManager::Instance().get_engine( {c10::kXPU, device_index}), stream.queue())); } return *stream_pool[device_index][priority][stream]; } GpuStreamManager(GpuStreamManager const&) = delete; GpuStreamManager& operator=(GpuStreamManager const&) = delete; GpuStreamManager(GpuStreamManager&&) = default; GpuStreamManager& operator=(GpuStreamManager&&) = default; protected: GpuStreamManager() { c10::DeviceIndex device_count = c10::xpu::device_count(); TORCH_INTERNAL_ASSERT(device_count > 0); stream_pool.resize(device_count); } ~GpuStreamManager() = default; private: using stream_hash_map = ska::flat_hash_map>; std::vector< std::array> stream_pool; }; } // namespace at::native::onednn