#pragma once #include #include #if AT_MKLDNN_ENABLED() #include #ifndef IDEEP_PREREQ // Please find definitions of version numbers in ideep.hpp #if defined(IDEEP_VERSION_MAJOR) && defined(IDEEP_VERSION_MINOR) && \ defined(IDEEP_VERSION_PATCH) && defined(IDEEP_VERSION_REVISION) #define IDEEP_PREREQ(major, minor, patch, revision) \ (((IDEEP_VERSION_MAJOR << 16) + (IDEEP_VERSION_MINOR << 8) + \ (IDEEP_VERSION_PATCH << 0)) >= \ ((major << 16) + (minor << 8) + (patch << 0)) && \ (IDEEP_VERSION_REVISION >= revision)) #else #define IDEEP_PREREQ(major, minor, patch, revision) 0 #endif #endif namespace at { namespace native { // Mapping ScalarType to ideep tensor data_type TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type); static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) { return get_mkldnn_dtype(t.scalar_type()); } TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor); TORCH_API at::Tensor mkldnn_tensor_from_data_ptr( void* data_ptr, at::IntArrayRef dims, at::ScalarType dtype, at::Device device, const uint8_t* opaque_metadata, int64_t opaque_metadata_size); // Construct aten MKL-DNN tensor given an ideep tensor TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional dtype, std::optional device); // Retrieve `ideep::tensor` from MKL-DNN tensor TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor); TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor); // Construct an `ideep::tensor` "view" from dense tensor, note the // ideep::tensor will share the underlying buffer TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false); // Construct an `ideep::tensor` "view" from dense tensor using given desc, note // the ideep::tensor will share the underlying buffer TORCH_API ideep::tensor itensor_view_from_dense( const at::Tensor& tensor, const ideep::tensor::desc& desc); // Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor. TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_const_data_ptr=false); // Set MKLDNN verbose level TORCH_API int set_verbose(int level); }} #endif // AT_MKLDNN_ENABLED