#pragma once

#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <optional>
#include <utility>

namespace torch::autograd::utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.

// Checks if grad obeys the contract with variable.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (variable.is_nested()) {
    // TODO: Nested Tensor does not have an implementation of detach. The
    // current implementation of nested tensor likely does obey the gradient
    // contract and should return true, but this may change in the future.
    return false;
  } else if (variable.is_sparse()) {
    // The Gradient Layout Contract is not applicable for sparse layouts.
    return false;
  } else if (variable.is_non_overlapping_and_dense()) {
    // Only look at strides for dimensions that are not of size 1.
    const auto& grad_sizes = grad.sym_sizes();
    const auto& grad_strides = grad.sym_strides();
    const auto& variable_strides = variable.sym_strides();
    for (const auto idx : c10::irange(grad_sizes.size())) {
      if (grad_sizes[idx] != 1) {
        if (grad_strides[idx] != variable_strides[idx]) {
          return false;
        }
      } else {
        // This should not be needed, but we don't check whether a Tensor has
        // views before stashing it, and 0-strided Tensors of size 1 are
        // actually views for ops like cat.
        // TODO: Actually detect views in the accumulateGrad function so that
        // this Tensor is not considered at all.
        if (grad_strides[idx] == 0) {
          return false;
        }
      }
    }
    return true;
  } else {
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }
}

// Creates a clone of new_grad that obeys the contract with variable.
// The clone should attach to new_grad's history if GradMode::is_enabled().
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (variable.is_non_overlapping_and_dense()) {
    // (1)
    // Does this dicey-looking sequence attach the result to new_grad's
    // history if GradMode::is_enabled()? Yes, and @alband says it should.
    return std::move(new_grad
                         .new_empty_strided_symint(
                             variable.sym_sizes(),
                             variable.sym_strides(),
                             variable.options().memory_format(std::nullopt))
                         .copy_(new_grad));
  } else {
    // (2)
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }
}

} // namespace torch::autograd::utils
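
// Illustrative usage (a minimal sketch, not part of this header): a caller
// such as AccumulateGrad would check the contract first and only pay for a
// clone when it is violated. `stash_grad` and its parameters are hypothetical
// names introduced here for illustration only.
//
//   at::Tensor stash_grad(
//       const at::Tensor& variable,
//       const at::Tensor& new_grad) {
//     namespace utils = torch::autograd::utils;
//     if (utils::obeys_layout_contract(new_grad, variable)) {
//       // new_grad already matches variable's layout; reuse it as-is.
//       return new_grad;
//     }
//     // Otherwise copy it into a tensor whose strides follow the contract.
//     return utils::clone_obey_contract(new_grad, variable);
//   }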