#ifndef MetalTensorImpl_h #define MetalTensorImpl_h #include #include #import #import namespace at { template struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl { MetalTensorImpl( at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes, c10::IntArrayRef strides) : OpaqueTensorImpl( key_set, data_type, device, opaque_handle, sizes), strides_(strides.vec()) { } // TODO: manually storing strides here is dumb IntArrayRef strides_custom() const override { return strides_; } c10::SymIntArrayRef sym_strides_custom() const override { return c10::fromIntArrayRefKnownNonNegative(strides_); } bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { return true; } private: const char* tensorimpl_type_name() const override { return "MetalTensorImpl"; } SmallVector strides_; }; } // namespace at #endif /* MetalTensorImpl_h*/