#ifndef MPSImageWrapper_h
#define MPSImageWrapper_h

// NOTE(review): the header paths for these three directives were missing in
// the reviewed copy (only bare `#import #import #include` keywords remained,
// which is ill-formed). The paths below are reconstructions chosen to provide
// the names this header uses (MetalCommandBuffer, MPSImage, IntArrayRef) —
// confirm against the source tree.
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <c10/util/ArrayRef.h>

// Included directly because this header names std::vector / int64_t itself.
#include <cstdint>
#include <vector>

namespace at {
namespace native {
namespace metal {

/// Holds an `MPSImage`, an auxiliary `id` buffer object, the
/// `MetalCommandBuffer` it is associated with, and the image sizes, and
/// exposes host-copy and storage-allocation entry points.
///
/// Only declarations are visible in this header; per-member descriptions
/// below that go beyond the signature are inferred from names and should be
/// confirmed against the implementation file.
class API_AVAILABLE(ios(11.0), macos(10.13)) MPSImageWrapper {
 public:
  /// Creates a wrapper for an image described by `sizes`.
  /// NOTE(review): not `explicit`, so an `IntArrayRef` converts implicitly
  /// to MPSImageWrapper; kept as-is to avoid breaking existing callers.
  MPSImageWrapper(IntArrayRef sizes);
  ~MPSImageWrapper();

  /// Copies float data from the host pointer `inputData` into the wrapper.
  /// The number of floats read is not visible here — presumably derived
  /// from the wrapper's sizes; TODO confirm.
  void copyDataFromHost(const float* inputData);
  /// Copies the wrapper's contents into the caller-provided host pointer
  /// `hostData`. Required capacity is not visible here — TODO confirm.
  void copyDataToHost(float* hostData);

  /// Allocates backing storage for an image with the given `sizes`.
  void allocateStorage(IntArrayRef sizes);
  /// Allocates storage for `sizes` in association with `commandBuffer`
  /// (presumably a temporary allocation whose lifetime is tied to that
  /// command buffer — TODO confirm).
  void allocateTemporaryStorage(
      IntArrayRef sizes,
      MetalCommandBuffer* commandBuffer);

  /// Associates this wrapper with `buffer`.
  void setCommandBuffer(MetalCommandBuffer* buffer);
  /// Returns the associated command buffer (the member defaults to nil).
  MetalCommandBuffer* commandBuffer() const;

  /// Replaces the wrapped image.
  void setImage(MPSImage* image);
  /// Returns the wrapped image (the member defaults to nil).
  MPSImage* image() const;
  /// Returns the auxiliary buffer object (the member defaults to nil).
  id buffer() const;

  // Lifecycle hooks; their exact semantics live in the implementation file,
  // which is not visible from this header.
  void synchronize();
  void prepare();
  void release();

 private:
  // Presumably the sizes the image was created/allocated with.
  // NOTE(review): the element type was missing in the reviewed copy (a bare
  // `std::vector` does not compile); int64_t matches IntArrayRef's element
  // type (c10::ArrayRef<int64_t>).
  std::vector<int64_t> _imageSizes;
  MPSImage* _image = nil;
  // NOTE(review): plain `id` compiles, but a protocol qualifier (e.g.
  // id<MTLBuffer>) may have been lost with the other angle-bracket text —
  // confirm, and apply the same check to `_delegate` and `buffer()`.
  id _buffer = nil;
  MetalCommandBuffer* _commandBuffer = nil;
  id _delegate = nil;
};

} // namespace metal
} // namespace native
} // namespace at

#endif /* MPSImageWrapper_h */