#pragma once #include namespace torch::aot_inductor { template struct ThreadLocalCachedOutputTensor; template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {} void copy_data_from(const RAIIAtenTensorHandle& handle) { throw std::runtime_error("can't happen"); } AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } }; template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {} void copy_data_from(const AtenTensorHandle& handle) { throw std::runtime_error("can't happen"); } AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } }; template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {} void copy_data_from(const ConstantHandle& handle) { throw std::runtime_error("can't happen"); } AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } }; template struct ThreadLocalCachedOutputTensor> { explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor& t) { realloc(t); } void copy_data_from(const ArrayRefTensor& t) { if (t.numel() > capacity_) { realloc(t); } std::copy(t.data(), t.data() + t.numel(), storage_.get()); } AtenTensorHandle tensor() const { return tensor_.get(); } private: void realloc(const ArrayRefTensor& t) { capacity_ = t.numel(); // NOLINTNEXTLINE(*arrays*) storage_ = std::make_unique(t.numel()); AtenTensorHandle handle = nullptr; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( storage_.get(), t.sizes().size(), t.sizes().data(), t.strides().data(), 0, aoti_torch_dtype>(), t.device_type(), t.device_idx(), &handle)); tensor_ = handle; } // NOLINTNEXTLINE(*arrays*) std::unique_ptr storage_; int64_t capacity_ = 0; RAIIAtenTensorHandle tensor_; }; template struct ThreadLocalCachedOutputArray; // Just needs to compile, doesn't need to do anything. template <> struct ThreadLocalCachedOutputArray { explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) { throw std::runtime_error("can't happen"); } // Not supported yet! We would need to put contiguous() or // expect_contiguous() into the ABI. void copy_data_from(const RAIIAtenTensorHandle&) { throw std::runtime_error("can't happen"); } template ArrayRefTensor arrayref_tensor() const { throw std::runtime_error("can't happen"); } }; // Just needs to compile, doesn't need to do anything. template <> struct ThreadLocalCachedOutputArray { explicit ThreadLocalCachedOutputArray(const ConstantHandle&) { throw std::runtime_error("can't happen"); } // Not supported yet! We would need to put contiguous() or // expect_contiguous() into the ABI. void copy_data_from(const ConstantHandle&) { throw std::runtime_error("can't happen"); } template ArrayRefTensor arrayref_tensor() const { throw std::runtime_error("can't happen"); } }; template struct ThreadLocalCachedOutputArray> { explicit ThreadLocalCachedOutputArray(const ArrayRefTensor& t) {} template < typename U, std::enable_if_t< std::is_same_v, std::remove_const_t>, bool> = true> ArrayRefTensor arrayref_tensor() const { return tensor_; } void copy_data_from(const ArrayRefTensor& t) { if (t.numel() > capacity_) { capacity_ = t.numel(); // NOLINTNEXTLINE(*arrays*) storage_ = std::make_unique(capacity_); } std::copy(t.data(), t.data() + t.numel(), storage_.get()); tensor_ = t; tensor_.set_arrayref(MiniArrayRef(storage_.get(), t.numel())); } private: // NOLINTNEXTLINE(*arrays*) std::unique_ptr storage_; uint32_t capacity_ = 0; ArrayRefTensor tensor_; }; } // namespace torch::aot_inductor