#pragma once // Wrap tensor operation outputs as PyObject* #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::autograd::utils { inline PyObject* wrap(bool value) { if (value) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } inline PyObject* wrap(c10::DeviceIndex value) { return THPUtils_packDeviceIndex(value); } inline PyObject* wrap(int64_t value) { return THPUtils_packInt64(value); } inline PyObject* wrap(double value) { return PyFloat_FromDouble(value); } inline PyObject* wrap(c10::complex value) { // I could probably also use FromComplex with a reinterpret cast, // but... eh. return PyComplex_FromDoubles(value.real(), value.imag()); } inline PyObject* wrap(void* value) { return PyLong_FromVoidPtr(value); } inline PyObject* wrap(THPDtype* dtype) { return Py_NewRef(dtype); } inline PyObject* wrap(at::ScalarType scalarType) { return Py_NewRef(getTHPDtype(scalarType)); } inline PyObject* wrap(THPLayout* layout) { return Py_NewRef(layout); } inline PyObject* wrap(at::Layout layout) { return Py_NewRef(getTHPLayout(layout)); } inline PyObject* wrap(const at::Tensor& tensor) { return THPVariable_Wrap(tensor); } inline PyObject* wrap(const at::Scalar& scalar) { return wrap(scalar_to_tensor(scalar)); } inline PyObject* wrap(at::QScheme qscheme) { auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme); Py_INCREF(thp_qscheme); return thp_qscheme; } inline PyObject* wrap(at::TensorList tl) { auto r = THPObjectPtr{PyTuple_New(static_cast(tl.size()))}; if (!r) throw python_error(); for (const auto i : c10::irange(tl.size())) { PyTuple_SET_ITEM(r.get(), i, wrap(tl[i])); } return r.release(); } inline PyObject* wrap(at::IntArrayRef list) { auto r = THPObjectPtr{PyTuple_New(static_cast(list.size()))}; if (!r) throw python_error(); for (const auto i : c10::irange(list.size())) { PyTuple_SET_ITEM(r.get(), i, wrap(list[i])); } return r.release(); } inline PyObject* wrap(at::Stream stream) { return THPStream_Wrap(stream); } namespace detail { template void apply_with_idx_impl( const F& f, Tuple& t, std::index_sequence /*indices*/) { (void)std::initializer_list{(f(std::get(t), Is), 0)...}; } // For tuple(a, b, c), calls f(a, 0), f(b, 1), f(c, 2) template void apply_with_idx(const F& f, std::tuple& t) { apply_with_idx_impl(f, t, std::index_sequence_for{}); } } // namespace detail template PyObject* wrap(std::tuple values) { auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))}; if (!r) throw python_error(); detail::apply_with_idx( [&](auto& value, size_t idx) { PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value))); }, values); return r.release(); } template PyObject* wrap(PyTypeObject* type, std::tuple values) { auto r = THPObjectPtr{PyStructSequence_New(type)}; if (!r) throw python_error(); detail::apply_with_idx( [&](auto& value, size_t idx) { PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value))); }, values); return r.release(); } } // namespace torch::autograd::utils