#include <cstddef>
#include <limits>

#include <c10/xpu/XPUStream.h>
#include <torch/extension.h>
#include <sycl/sycl.hpp>

/// Per-work-item body: stores sigmoid(x[i]) + sigmoid(y[i]) into output[i].
///
/// The element index i is taken from dimension 2 of the nd_item as
/// (work-group id * work-group size + local id). Work-items whose index is
/// not below `size` return without touching memory.
///
/// @param x        first input array, read at index i
/// @param y        second input array, read at index i
/// @param output   destination array, written at index i
/// @param size     number of elements to process
/// @param item_ct1 identifies the calling work-item within the launch
void sigmoid_add_kernel(const float* x,
                        const float* y,
                        float* output,
                        const int size,
                        const sycl::nd_item<3> &item_ct1) {
    const auto group_offset =
        item_ct1.get_group(2) * item_ct1.get_local_range(2);
    const int index =
        static_cast<int>(group_offset + item_ct1.get_local_id(2));

    // The launch may be padded past `size`; surplus work-items do nothing.
    if (index >= size) {
        return;
    }

    // Logistic function 1 / (1 + e^-v), evaluated with sycl::native::exp.
    const auto sigmoid = [](const float v) {
        return 1.0f / (1.0f + sycl::native::exp(-v));
    };

    output[index] = sigmoid(x[index]) + sigmoid(y[index]);
}

/// Kernel function object passed to parallel_for.
///
/// Stores the two input pointers, the output pointer and the element count,
/// and forwards every work-item invocation to sigmoid_add_kernel.
class SigmoidAddKernel {
public:
    SigmoidAddKernel(const float* x_in, const float* y_in, float* out, int count)
        : x{x_in}, y{y_in}, output{out}, size{count} {}

    /// Invoked once per work-item; delegates to the free kernel function.
    void operator()(const sycl::nd_item<3> &item_ct1) const {
        sigmoid_add_kernel(x, y, output, size, item_ct1);
    }

private:
    const float* x;   // first input array
    const float* y;   // second input array
    float* output;    // destination array
    int size;         // number of elements to process
};

void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
    SigmoidAddKernel krn(x, y, output, size);
    const int threads = 1024;
    const int blocks = (size + threads - 1) / threads;

    sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
    queue.submit([&](sycl::handler &cgh) {
        cgh.parallel_for<SigmoidAddKernel>(
            sycl::nd_range<3>(
                sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
                sycl::range<3>(1, 1, threads)),
            krn);
        });
}

/// Computes sigmoid(x) + sigmoid(y) element-wise on XPU tensors.
///
/// Both inputs must be float32 XPU tensors of identical shape on the same
/// device. Inputs are made contiguous before the raw pointers are handed to
/// the kernel, because the kernel indexes them as flat arrays.
///
/// @param x first input tensor
/// @param y second input tensor
/// @return a new contiguous tensor with the shape and options of `x`
/// @throws c10::Error (via TORCH_CHECK) if any precondition above fails
///
/// NOTE(review): the kernel is launched on the *current* XPU stream; if the
/// inputs may live on a non-current device, a device guard is probably
/// needed here — confirm against callers.
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
  TORCH_CHECK(x.device() == y.device(), "x and y must be on the same device");
  TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be a float32 tensor");
  TORCH_CHECK(y.scalar_type() == torch::kFloat, "y must be a float32 tensor");
  // The kernel reads x[i] and y[i] for every output index, so a smaller or
  // differently shaped y would be read out of bounds or mismatched.
  TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
  // The launcher and kernel index with `int`; reject sizes that would be
  // silently truncated.
  TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(),
              "sigmoid_add supports at most ", std::numeric_limits<int>::max(),
              " elements, got ", x.numel());

  // data_ptr() exposes raw storage, so strided inputs must be densified.
  const torch::Tensor x_contig = x.contiguous();
  const torch::Tensor y_contig = y.contiguous();

  // Every element is overwritten by the kernel, so no zero-fill is needed.
  torch::Tensor output = torch::empty(x_contig.sizes(), x_contig.options());

  const int count = static_cast<int>(output.numel());
  if (count > 0) {
    sigmoid_add_xpu(x_contig.data_ptr<float>(),
                    y_contig.data_ptr<float>(),
                    output.data_ptr<float>(),
                    count);
  }
  return output;
}

// Python bindings: exposes sigmoid_add as a function named "sigmoid_add" on
// the extension module, with the docstring "sigmoid(x) + sigmoid(y)".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}
