#include #include namespace at::native::metal { class MPSImageWrapper; class MetalTensorImplStorage final { class Impl; public: MetalTensorImplStorage() = default; MetalTensorImplStorage(const std::vector& sizes); MetalTensorImplStorage( const std::vector& sizes, const std::vector& strides); ~MetalTensorImplStorage() = default; MetalTensorImplStorage(MetalTensorImplStorage&&) = default; MetalTensorImplStorage& operator=(MetalTensorImplStorage&&) = default; MetalTensorImplStorage(const MetalTensorImplStorage&) = default; MetalTensorImplStorage& operator=(const MetalTensorImplStorage&) = default; friend std::ostream& operator<<( std::ostream& output, const MetalTensorImplStorage& mt); bool defined() const; IntArrayRef sizes() const; IntArrayRef strides() const; int64_t dim() const; int64_t numel() const; void set_data_from_host(const float* inputData); void copy_data_to_host(float* host); MPSImageWrapper* texture() const; private: std::shared_ptr impl(); std::shared_ptr impl() const; std::shared_ptr _impl; }; } // namespace at::native::metal