#pragma once #include #include #include #include #include #include namespace torch::cuda::CUDAPluggableAllocator { using MallocFuncType = void*(size_t, int, cudaStream_t); using FreeFuncType = void(void*, size_t, int, cudaStream_t); // A CUDAPluggableAllocatorDeleterContext object is used as the `ctx` // argument for DataPtr. We need context because a user can use // multiple allocators in the same PyTorch program, and // the allocators can have different free functions, such as: // free, cudaFree, cudaFreeAsync, ncclMemFree etc. struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { explicit CUDAPluggableAllocatorDeleterContext( std::function free_fn, void* data, size_t size, int device, cudaStream_t stream); void free(); private: std::function free_fn_; void* data_; size_t size_; int device_; cudaStream_t stream_{}; }; #if defined(USE_ROCM) using streamType = c10::hip::HIPStream; #else using streamType = c10::cuda::CUDAStream; #endif TORCH_CUDA_CPP_API std::shared_ptr< c10::cuda::CUDACachingAllocator::CUDAAllocator> getCurrentAllocator(); TORCH_CUDA_CPP_API std::shared_ptr< c10::cuda::CUDACachingAllocator::CUDAAllocator> createCustomAllocator( std::function alloc_fn, std::function free_fn); TORCH_CUDA_CPP_API void changeCurrentAllocator( const std::shared_ptr& allocator); struct _AllocationMetadata { _AllocationMetadata(); _AllocationMetadata( size_t size, c10::DeviceIndex device_idx, cudaStream_t stream); size_t size; c10::DeviceIndex device_idx; cudaStream_t stream{}; }; struct TORCH_CUDA_CPP_API CUDAPluggableAllocator : public c10::cuda::CUDACachingAllocator::CUDAAllocator { CUDAPluggableAllocator( std::function alloc_fn, std::function free_fn); CUDAPluggableAllocator(CUDAPluggableAllocator& other); CUDAPluggableAllocator(CUDAPluggableAllocator&& other) = delete; CUDAPluggableAllocator& operator=(const CUDAPluggableAllocator& other) = delete; CUDAPluggableAllocator& operator=(CUDAPluggableAllocator&& other) = delete; ~CUDAPluggableAllocator() override = default; void set_init_fn(std::function init_fn); void set_reset_fn(std::function reset_fn); void set_memory_fraction_fn( std::function memory_fraction_fn); void set_base_alloc_fn(std::function base_alloc_fn); void set_record_stream_fn( std::function record_stream_fn); void set_begin_allocate_to_pool( std::function< void(int, c10::cuda::MempoolId_t, std::function)> capture_begin_fn); void set_end_allocate_to_pool_fn( std::function capture_about_to_end_fn); void set_release_pool( std::function capture_destroy_fn); void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream); c10::DataPtr allocate(size_t size) override; c10::DeleterFnPtr raw_deleter() const override; void* raw_alloc(size_t nbytes) override; void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override; void raw_delete(void* ptr) override; void init(int device_count) override; bool initialized() override; double getMemoryFraction(c10::DeviceIndex device) override; void setMemoryFraction(double fraction, c10::DeviceIndex device) override; void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override; void enable(bool) override {} bool isEnabled() const override { return true; } void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override; void* getBaseAllocation(void* ptr, size_t* size) override; void recordStream(const c10::DataPtr&, streamType stream) override; c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot( c10::cuda::MempoolId_t mempool) override; void beginAllocateToPool( c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id, std::function) override; void endAllocateToPool( c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) override; void releasePool(c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) override; std::shared_ptr getIpcDevPtr(std::string handle) override; c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle( void*) override; void recordHistory( bool enabled, c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, size_t alloc_trace_max_entries, c10::cuda::CUDACachingAllocator::RecordContext when, bool clearHistory) override; void attachOutOfMemoryObserver( c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; void attachAllocatorTraceTracker( c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override; std::shared_ptr getCheckpointState(c10::DeviceIndex device, at::cuda::MempoolId_t id) override; c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState( c10::DeviceIndex device, std::shared_ptr pps) override; void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override; cudaError_t memcpyAsync( void* dst, int dstDevice, const void* src, int srcDevice, size_t count, cudaStream_t stream, bool p2p_enabled) override; std::string name() override; void copy_data(void* dest, const void* src, std::size_t count) const final; protected: std::function alloc_fn_; std::function free_fn_; std::function init_fn_; std::function reset_fn_; std::function memory_fraction_fn_; std::function base_alloc_fn_; std::function record_stream_fn_; std::function< void(int, c10::cuda::MempoolId_t, std::function)> begin_allocate_to_pool_fn_; std::function end_allocate_to_pool_fn_; std::function relase_pool_fn_; std::mutex allocator_mutex_; // We do the bookkeeping here in order to simplify custom allocators std::unordered_map allocation_metadata_; bool initialized_ = false; }; } // namespace torch::cuda::CUDAPluggableAllocator