/*
 * Copyright © Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (C) 2024-2025, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));
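
// Host-side bindings around vllm::CustomAllreduce. A typical call sequence,
// sketched from the functions below (the actual orchestration lives in the
// Python layer that consumes these bindings):
//   1. each rank allocates a meta buffer of at least meta_size() bytes
//      (allocate_meta_buffer on ROCm builds) and exchanges its IPC handle
//      (get_meta_buffer_ipc_handle) with the other ranks out of band;
//   2. init_custom_ar creates the CustomAllreduce instance and returns an
//      opaque fptr_t handle;
//   3. register_buffer / register_graph_buffers make peer memory visible;
//   4. all_reduce_reg (pre-registered input) or all_reduce_unreg (input
//      staged through a registered buffer) performs the reduction;
//   5. dispose releases the instance.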

namespace aiter {

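/**
 * Create a vllm::CustomAllreduce instance and return it as an opaque fptr_t.
 *
 * meta:        device buffer holding the vllm::Signal flags (at least
 *              meta_size() bytes), shared with the peers via IPC.
 * rank_data:   device scratch buffer; its pointer and element count are
 *              forwarded to CustomAllreduce unchanged.
 * handles:     raw bytes of one cudaIpcMemHandle_t per rank for the peers'
 *              meta buffers.
 * offsets:     per-rank byte offsets into the corresponding IPC allocations;
 *              its length defines the world size.
 * rank:        this process's rank in [0, world_size).
 * full_nvlink: topology flag forwarded to CustomAllreduce.
 */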
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int64_t rank,
                      bool full_nvlink)
{
  int world_size = static_cast<int>(offsets.size());
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd number of GPUs is not supported for now");
  if (world_size != static_cast<int>(handles.size()))
    throw std::invalid_argument(
        "handles length should equal offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++)
  {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within
 * [t.data_ptr(), (char *)t.data_ptr() + t.numel() * t.element_size()).
 * This is slightly weaker than t.is_contiguous() because it also allows
 * transposes of a contiguous slice (i.e. a slice along the first dimension).
 * Currently we require this because stride information is not passed into
 * the kernels and input tensors are treated as flat buffers.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t)
{
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

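// Dispatch the all-reduce on the given stream according to the output dtype.
// For half inputs, when open_fp8_quant is set and the element count allows
// it, the fp8-quantized kernel is used instead of the plain allreduce.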
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream, bool open_fp8_quant)
{
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type())
  {
  case at::ScalarType::Float:
  {
    fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                         reinterpret_cast<float *>(out.data_ptr()),
                         out.numel());
    break;
  }
  case at::ScalarType::Half:
  {
    // Supported case for the fp8-quantized kernel: 512 threads per block,
    // 8 half values packed per thread, input split into 8 pieces, so the
    // element count must be a multiple of 4096 * 8.
    if (open_fp8_quant && inp.numel() % (4096 * 8) == 0)
    {
      fa->runFp8QuantKernel(stream, reinterpret_cast<half *>(inp.data_ptr()),
                            reinterpret_cast<half *>(out.data_ptr()), out.numel());
    }
    else
    {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()), out.numel());
    }
    break;
  }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
  case at::ScalarType::BFloat16:
  {
    fa->allreduce<nv_bfloat16>(
        stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
        reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
    break;
  }
#endif
  default:
    throw std::runtime_error(
        "custom allreduce only supports float32, float16 and bfloat16");
  }
}

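// All-reduce for an input that already resides in IPC-registered memory
// (e.g. registered via register_buffer), so no staging copy is needed.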
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, bool open_fp8_quant)
{
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream, open_fp8_quant);
}

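// All-reduce for an unregistered input: the input is first copied into
// reg_buffer (which must be registered and large enough), and the reduction
// then runs out of that staging buffer.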
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out)
{
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream, false);
}

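// Destroy a CustomAllreduce instance previously created by init_custom_ar.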
void dispose(fptr_t _fa)
{
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  delete fa;
}

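// Size of the vllm::Signal structure; the meta buffer passed to
// init_custom_ar must be at least this many bytes.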
int64_t meta_size() { return sizeof(vllm::Signal); }

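// Register a tensor's memory, together with the peers' IPC handles and
// offsets for the same buffer, so it can be used directly by all_reduce_reg.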
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets)
{
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

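// Return the IPC handles of the buffers recorded during graph capture as a
// CPU uint8 tensor, together with their offsets, so they can be exchanged
// with the other ranks.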
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa)
{
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handles =
      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
  return {handles, std::move(offsets)};
}

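// Register the graph-capture buffers of all ranks, using the handles and
// per-rank offsets gathered via get_graph_buffer_ipc_meta.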
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets)
{
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}

#ifdef USE_ROCM

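// Deleter for the tensor returned by allocate_meta_buffer (torch::from_blob).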
void free_meta_buffer(void *buffer) { CUDACHECK(cudaFree(buffer)); }

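// Export an IPC handle for the given device buffer as a CPU uint8 tensor of
// sizeof(cudaIpcMemHandle_t) bytes, to be shared with the other ranks.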
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor &inp)
{
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto data_handle =
      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
  CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)data_handle.data_ptr(),
                                inp.data_ptr()));
  return data_handle;
}

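// Allocate a zero-initialized, uncached device buffer
// (hipDeviceMallocUncached) for the signal/meta region. The thread's
// stream-capture mode is temporarily switched to relaxed so the allocation
// and memset are permitted even while a graph capture is in progress; the
// second exchange restores the previous mode.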
torch::Tensor allocate_meta_buffer(int64_t size)
{
  auto device_index = c10::cuda::current_device();
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  void *buffer;
  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
  AT_CUDA_CHECK(
      hipExtMallocWithFlags((void **)&buffer, size, hipDeviceMallocUncached));
  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
  auto options = torch::TensorOptions()
                     .dtype(torch::kI8)
                     .device(torch::kCUDA, device_index);
  return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}

#endif

} // namespace aiter
