#pragma once #include #include #include #include #include #include #include #include namespace torch::jit::fuser { // type information needed by the compiler for input/outputs // contiguity[i] is true if the dim i is contiguous with dim i + 1. // contiguity.back() == true means strides.back() == 1. struct TORCH_API TensorDesc { at::ScalarType scalar_type; std::vector contiguity; TensorDesc(const at::ScalarType& type, const std::vector& contiguity) : scalar_type{type}, contiguity{contiguity} { if (contiguity.empty()) { nDim_ = 0; } else { nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0); } } // Delegating constructors TensorDesc( const at::ScalarType& type, const at::IntArrayRef& sizes, const at::IntArrayRef& strides) : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} TensorDesc(const at::Tensor& t) : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} TensorDesc(const c10::TensorTypePtr& type) : TensorDesc( type->scalarType().value(), type->sizes().concrete_sizes().value(), type->strides().concrete_sizes().value()) {} // number of dimensions after contiguity compression size_t nDim() const { return nDim_; } // True iff innermost stride is 1 bool lastIsContiguous() const { return (contiguity.empty() || contiguity.back()); } static std::vector findContiguous( const at::IntArrayRef& sizes, const at::IntArrayRef& strides) { AT_ASSERT(sizes.size() == strides.size()); std::vector cont(sizes.size()); for (size_t i = 0; i < sizes.size(); ++i) { const auto expected_stride = (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1; cont[i] = (strides[i] == expected_stride); } return cont; } bool operator==(const TensorDesc& desc) const { return scalar_type == desc.scalar_type && contiguity == desc.contiguity; } bool operator!=(const TensorDesc& desc) const { return !(*this == desc); } static size_t hash(const TensorDesc& spec) { return c10::get_hash( spec.scalar_type, spec.nDim_, std::hash>{}(spec.contiguity)); } private: size_t nDim_; }; inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { out << d.scalar_type << "["; for (const auto b : d.contiguity) out << b << ";"; out << "]"; return out; } } // namespace torch::jit::fuser