#pragma once

#include <torch/csrc/inductor/aoti_runtime/mini_array_ref.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>

#include <assert.h>
#include <cstdint>
#include <cstring>
#include <type_traits>

namespace torch::aot_inductor {

using MiniIntArrayRef = MiniArrayRef<int64_t>;

static_assert(
    sizeof(MiniIntArrayRef) == sizeof(void*) + sizeof(size_t),
    "changing the size of MiniArrayRef breaks ABI compatibility!");

inline bool is_contiguous_strides_for_shape(
    int64_t ndim,
    const int64_t* strides_ptr,
    const int64_t* sizes_ptr) {
  int64_t z = 1;
  for (int64_t d = ndim - 1; d >= 0; d--) {
    const auto& size_d = sizes_ptr[d];
    if (size_d != 1) {
      if (strides_ptr[d] == z) {
        z *= size_d;
      } else {
        return false;
      }
    }
  }
  return true;
}

// Shim for AOTI generated code to pretend a raw array works like an
// AtenTensorHandle.
template <typename T>
class ArrayRefTensor {
 public:
  ArrayRefTensor() = default;

  explicit ArrayRefTensor(
      MiniArrayRef<T> arr,
      MiniArrayRef<const int64_t> sizes,
      MiniArrayRef<const int64_t> strides,
      int32_t device_type,
      int32_t device_idx)
      : arrayRef_(arr),
        sizes_(sizes),
        strides_(strides),
        device_type_(device_type),
        device_idx_(device_idx) {
    assert(sizes.size() == strides.size());
    assert(is_contiguous_strides_for_shape(
        sizes.size(), strides.data(), sizes.data()));
  }

  AtenTensorHandle expensiveCopyToTensor() const {
    AtenTensorHandle result = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
        sizes_.size(),
        sizes_.data(),
        strides_.data(),
        aoti_torch_dtype<std::remove_const_t<T>>(),
        device_type_,
        device_idx_,
        &result));
    void* dataPtr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr));
    std::memcpy(dataPtr, data(), numel() * sizeof(T));
    return result;
  }

  // We need to look the same as RAIIAtenTensorHandle, which returns
  // an owning AtenTensorHandle from release(). So, we allocate one!
  AtenTensorHandle release() {
    return expensiveCopyToTensor();
  }

  AtenTensorHandle borrowAsTensor() const {
    AtenTensorHandle result = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
        data(),
        sizes_.size(),
        sizes_.data(),
        strides_.data(),
        0,
        aoti_torch_dtype<std::remove_const_t<T>>(),
        device_type_,
        device_idx_,
        &result,
        aoti_torch_layout_strided(),
        nullptr,
        0));
    return result;
  }

  // We don't need to free any memory.
  void reset() {}

  auto sizes() const {
    return sizes_;
  }

  auto strides() const {
    return strides_;
  }

  auto device_type() const {
    return device_type_;
  }

  auto device_idx() const {
    return device_idx_;
  }

  T* data() const {
    return arrayRef_.data();
  }

  auto numel() const {
    return arrayRef_.size();
  }

  void set_arrayref(MiniArrayRef<T> new_arrayref) {
    arrayRef_ = new_arrayref;
  }

 private:
  MiniArrayRef<T> arrayRef_;
  // We expect generated code to have statically available sizes &
  // strides for us.
  MiniArrayRef<const int64_t> sizes_;
  MiniArrayRef<const int64_t> strides_;
  int32_t device_type_ = 0;
  int32_t device_idx_ = 0;
  // We continue to zero-initialize this field in case we repurpose
  // the space later; having predictable contents can only help.
  int32_t unusedDoNotRemoveForABICompatibility_ = 0;
};

static_assert(
    sizeof(ArrayRefTensor<int>) ==
        3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) +
            (alignof(ArrayRefTensor<int>) > 4 ? sizeof(int32_t) : 0),
    "changing the size of ArrayRefTensor breaks ABI compatibility!");

template <typename T>
inline ArrayRefTensor<T> reinterpret_tensor_wrapper(
    const ArrayRefTensor<T>& self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset) {
  // REVIEW: we should add a way to build the DSO in debug mode during
  // tests so we can have checks like this!
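  // ArrayRefTensor can only describe a contiguous view over its buffer (the
  // constructor asserts this), so a reinterpretation is only valid when the
  // requested sizes/strides are themselves contiguous; storage_offset simply
  // advances the data pointer within the original allocation.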
  assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr));
  return ArrayRefTensor<T>(
      MiniArrayRef<T>(
          self.data() + storage_offset, self.numel() - storage_offset),
      MiniArrayRef<const int64_t>(sizes_ptr, ndim),
      MiniArrayRef<const int64_t>(strides_ptr, ndim),
      self.device_type(),
      self.device_idx());
}

template <typename T>
inline T* get_data_ptr_wrapper(ArrayRefTensor<T>& tensor) {
  return tensor.data();
}

template <typename T>
inline T* get_data_ptr_wrapper(const MiniArrayRef<T>& arr) {
  return arr.data();
}

template <typename T>
inline const ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline const ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T> wrap_with_raii_handle_if_needed(
    ArrayRefTensor<T>&& tensor) {
  return std::move(tensor);
}

template <typename T>
inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor.expensiveCopyToTensor();
}

inline AtenTensorHandle expensive_copy_to_tensor_if_needed(
    AtenTensorHandle handle) {
  return handle;
}

template <typename T>
const T& copy_arrayref_tensor_to_tensor(const T& t) {
  return t;
}

template <typename T>
RAIIAtenTensorHandle copy_arrayref_tensor_to_tensor(
    const ArrayRefTensor<T>& art) {
  return art.expensiveCopyToTensor();
}

template <typename T>
const T& borrow_arrayref_tensor_as_tensor(const T& t) {
  return t;
}

template <typename T>
RAIIAtenTensorHandle borrow_arrayref_tensor_as_tensor(
    const ArrayRefTensor<T>& art) {
  return art.borrowAsTensor();
}

} // namespace torch::aot_inductor
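
// Illustrative usage sketch (not part of this header): one way generated code
// might wrap a pre-allocated buffer in an ArrayRefTensor and obtain a real
// tensor handle from it. The buffer, sizes, and strides below are
// hypothetical; only the types and helpers come from this header and the
// AOTI shim.
//
//   float buf[6] = {};
//   const int64_t sizes[] = {2, 3};
//   const int64_t strides[] = {3, 1};
//   torch::aot_inductor::ArrayRefTensor<float> t(
//       torch::aot_inductor::MiniArrayRef<float>(buf, 6),
//       torch::aot_inductor::MiniArrayRef<const int64_t>(sizes, 2),
//       torch::aot_inductor::MiniArrayRef<const int64_t>(strides, 2),
//       aoti_torch_device_type_cpu(),
//       /*device_idx=*/0);
//   // Non-owning view of buf; no copy of the underlying data is made.
//   torch::aot_inductor::RAIIAtenTensorHandle borrowed =
//       torch::aot_inductor::borrow_arrayref_tensor_as_tensor(t);
//   // Owning copy; allocates a new tensor and memcpys, hence "expensive".
//   torch::aot_inductor::RAIIAtenTensorHandle copied =
//       torch::aot_inductor::copy_arrayref_tensor_to_tensor(t);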