#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/SparseTensorUtils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <optional>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {
// look what you made me do >.<
// Divergent paths for per-Impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
// Marks every storage backing `var` as being in use on `stream`.
//
// Nothing is recorded when the device index of `stream` differs from the
// device index of `var`. Otherwise each backing storage is passed to the
// device guard's recordDataPtrOnStream:
//   - batched tensor:            storage of the wrapped value
//   - CSR / CSC / BSR / BSC:     values, compressed indices, plain indices
//   - COO sparse:                values, indices
//   - strided:                   the tensor's own storage
// Any other layout trips an internal assert.
void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
  if (var.device().index() != stream.device_index()) {
    return;
  }

  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  const c10::impl::VirtualGuardImpl guard(device_of(var).value().type());

  // Records the storage of a single tensor on `stream`.
  const auto record = [&guard, &stream](const Variable& t) {
    guard.recordDataPtrOnStream(t.storage().data_ptr(), stream);
  };

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* batched = at::maybeGetBatchedImpl(var);
    if (batched != nullptr) {
      record(batched->value());
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
    return;
  }

  switch (var.layout()) {
    case c10::kStrided:
      record(var);
      return;
    case c10::kSparse: {
      auto* coo = at::sparse::get_sparse_impl(var);
      record(coo->values());
      record(coo->indices());
      return;
    }
    case c10::kSparseCsr:
    case c10::kSparseCsc:
    case c10::kSparseBsr:
    case c10::kSparseBsc: {
      auto* compressed = at::sparse_csr::get_sparse_csr_impl(var);
      record(compressed->values());
      record(compressed->compressed_indices());
      record(compressed->plain_indices());
      return;
    }
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unknown layout in record_stream_any_impl");
  }
}

// Returns true when `v` may be used as the destination of an in-place
// accumulation. All of the following must hold, checked in this order:
//   - `v` is not subclass-like, not a ZeroTensor and not nested,
//   - `v` is non-overlapping and dense,
//   - the adjusted use count of `v` is exactly one,
//   - `v` has a storage and that storage's use count is exactly one.
bool can_accumulate_inplace(const Variable& v) {
  // Anything other than a "vanilla" Tensor is rejected up front.
  if (at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) {
    return false;
  }

  // The memory layout has to be favorable.
  if (!v.is_non_overlapping_and_dense()) {
    return false;
  }

  // We must hold the last reference to the tensor ...
  if (at::caching::adjusted_use_count(v) != 1) {
    return false;
  }

  // ... and to its storage.
  if (!v.has_storage()) {
    return false;
  }
  return v.storage().use_count() == 1;
}
} // anonymous namespace

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}

// Note: [Stream sync contract when dealing with multi-deviced-ness]
//
// An operator can deal with multiple devices, e.g. if it does a device
// transfer. However, for the purpose of stream synchronization, the engine
// is only aware of a single canonical device/stream for each autograd Node.
//
// For the proper synchronization, the Node author should make sure of the
// following:
//
// 1) A node consuming a gradient should wait on the canonical stream before
//    using it.
// 2) A node producing a gradient should have it ready on the canonical
//    stream during node execution.
//

// Note: [Autograd Producer-Consumer Stream Syncs]
//
// The producer-consumer stream syncs are partially handled in
// InputBuffer::add and partially handled in the engine prior to the
// consumer's execution. The logic here is mainly responsible for handling the
// synchronization needed for accumulation and recording the event that the
// consumer should wait on later. The corresponding wait and record_stream
// happen in the engine.
//
// First producer
// ==============
// There are several things we need to do upon seeing the first producer:
// 1) Determine the accumulation stream (which may or may not be used):
//    case A) var's device matches consumer node's canonical device
//            (The producer node's canonical device may or may not match)
//            -> accumulator stream = consumer stream
//    case B) var's device matches producer node's canonical device
//            and does not match consumer node's canonical device
//            -> accumulator stream = producer stream
//    case C) var device matches neither
//            -> accumulator stream = var device's current stream
//            See Note [Stream sync contract when dealing with
//            multi-deviced-ness]
// 2) Because we are the first producer, there's no accumulation necessary.
//    Just move var into the buffer.
// 3) Update the ready_events and streams for the current position.**
//    ready_events are events you need to wait for to ensure the corresponding
//    buffers are ready. The events are updated as we accumulate into the
//    buffer.
//
// Nth producer
// ============
// 1) Synchronize for accumulation. Accumulation operates on both the new
//   incoming gradient and the existing gradient in the buffer.
//   (i) wait stream and (ii) record stream to make sure both are ready to be
//   used on the accumulation stream.
// 2) Accumulate on the accumulation stream
// 3) Update the ready event and stream for the current position.**
//
// **As an optimization, we avoid creating and recording an event if we
// know that we won't need to wait on it, saving on the order of microseconds.
//
// InputBuffer::add
//
// Stores `var` in slot `pos` of the buffer, or accumulates it into the
// gradient that slot already holds.
//
// Parameters:
//   pos                   Slot index; must be < buffer.size().
//   var                   Incoming gradient. An undefined `var` is ignored.
//   opt_producer_stream_  Canonical stream of the producer node, if it has
//                         one.
//   opt_consumer_stream_  Canonical stream of the consumer node, if it has
//                         one.
//
// For tensors on an accelerator this also updates opt_accum_streams[pos],
// ready_events[pos] and ready_streams[pos] following
// Note: [Autograd Producer-Consumer Stream Syncs].
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream_,
    const std::optional<c10::Stream>& opt_consumer_stream_) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());

  // Nothing to store for an undefined gradient; the slot is left untouched.
  if (!var.defined()) {
    return;
  }
  const auto device = var.device();
  const auto device_type = device.type();
  bool is_accelerator = at::accelerator::isAccelerator(device.type());
  //
  // Non-accelerator case
  //
  // No stream bookkeeping is done on this path: the gradient is either moved
  // into an empty slot or added to the existing one under a device guard for
  // `device`.
  if (!is_accelerator) {
    if (!buffer[pos].defined()) {
      buffer[pos] = std::move(var);
    } else {
      c10::OptionalDeviceGuard device_guard{device};
      accumulate(buffer, pos, std::move(var));
    }
    return;
  }
  // Handle the case where var is on an accelerator but producer node has no
  // canonical stream, e.g. this can happen if forward is DtoH
  //
  // The fallback is the current stream of var's device. Either arm of the
  // conditional yields a temporary std::optional; binding it to a const
  // reference keeps that temporary alive until the end of this function.
  const std::optional<c10::Stream>& opt_producer_stream =
      (opt_producer_stream_.has_value()
           ? opt_producer_stream_
           : std::optional<c10::Stream>(
                 at::accelerator::getCurrentStream(device.index())));

  // opt_consumer_stream is always non-null when is_accelerator is true
  // when InputBuffer is used in the engine. InputBuffer is also called
  // elsewhere however! (e.g. other engine implementations)
  //
  // Same fallback as for the producer stream above.
  const std::optional<c10::Stream>& opt_consumer_stream =
      (opt_consumer_stream_.has_value()
           ? opt_consumer_stream_
           : std::optional<c10::Stream>(
                 at::accelerator::getCurrentStream(device.index())));

  // After the fallbacks above both optionals hold a value.
  TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);

  // See Note: [Autograd Producer-Consumer Stream Syncs]
  //
  // A slot without an accumulation stream has not yet received an accelerator
  // gradient through this path, so this call is treated as the first producer.
  if (!opt_accum_streams[pos].has_value()) {
    // [ First producer ]
    TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
    // 1) Pick the accumulation stream for this slot.
    if (opt_consumer_stream->device() == device) {
      // Case A
      opt_accum_streams[pos] = opt_consumer_stream;
      if (*opt_consumer_stream != *opt_producer_stream) {
        // We will end up doing record_stream on the accumulation stream
        // (which is the consumer stream) later, but we also need to do
        // it here in case we don't end up accumulating.
        record_stream_any_impl(var, *opt_consumer_stream);
      }
    } else if (opt_producer_stream->device() == device) {
      // Case B
      opt_accum_streams[pos] = opt_producer_stream;
    } else {
      // Case C
      opt_accum_streams[pos] =
          at::accelerator::getCurrentStream(device.index());
    }
    // 2) No accumulation needed: the slot was empty, so just take the tensor.
    buffer[pos] = std::move(var);
    // 3) Publish the ready event/stream for this slot.
    auto& opt_accum_stream = opt_accum_streams[pos];
    TORCH_INTERNAL_ASSERT(opt_accum_stream.has_value());
    // The event is only created when some stream other than the producer
    // stream (consumer or accumulation) may have to wait for the gradient.
    if (*opt_consumer_stream != *opt_producer_stream ||
        *opt_accum_stream != *opt_producer_stream) {
      // Either the consumer or accum stream waits for the producer
      // stream depending on whether accumulation is needed.
      auto event = c10::Event{device_type};
      event.record(*opt_producer_stream);
      ready_events[pos] = std::move(event);
    }
    ready_streams[pos] = opt_producer_stream;
  } else {
    // [ Nth producer ]
    // accum_stream is a copy of the stream chosen by the first producer;
    // ready_event / ready_stream refer to the slot's current entries.
    auto accum_stream = opt_accum_streams[pos];
    auto& ready_event = ready_events[pos];
    auto& ready_stream = ready_streams[pos];
    TORCH_INTERNAL_ASSERT(accum_stream && ready_stream);
    // 1) Make both operands usable on the accumulation stream.
    if (*accum_stream != *opt_producer_stream) {
      // Incoming gradient: the accumulation stream waits on an event recorded
      // on the producer stream, and var's storage is recorded on the
      // accumulation stream.
      auto event = c10::Event{device_type};
      event.record(*opt_producer_stream);
      accum_stream->wait(event);
      record_stream_any_impl(var, *accum_stream);
    }
    if (*accum_stream != *ready_stream) {
      // Buffered gradient: it became ready on a different stream, so a ready
      // event must exist for the accumulation stream to wait on.
      TORCH_INTERNAL_ASSERT(ready_event);
      accum_stream->wait(*ready_event);
      // This is redundant for case A, but needed for case C
      record_stream_any_impl(buffer[pos], *accum_stream);
    }
    // 2) Accumulate with the stream guard constructed from accum_stream in
    //    scope.
    c10::OptionalStreamGuard stream_guard{accum_stream};
    accumulate(buffer, pos, std::move(var));
    // 3) Publish the new ready event/stream. When the consumer stream is the
    //    accumulation stream no event is recorded and ready_events[pos] keeps
    //    whatever it held before.
    if (*opt_consumer_stream != *accum_stream) {
      // Only the consumer stream needs to wait for this event
      auto event = c10::Event{device_type};
      event.record(*accum_stream);
      ready_events[pos] = std::move(event);
    }
    ready_streams[pos] = accum_stream;
  }
}

auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}

} // namespace torch::autograd
