#pragma once

#include <torch/library.h>

#include <cmath>
#include <limits>

/**
 * PyBind and PyTorch Library APIs generally require different type
 * signatures. This file provides a shim (mostly; there may be missing
 * conversions) to convert a function designed to be used with PyBind into one
 * that can be used with the PyTorch Library API. This is done using
 * `make_pytorch_shim`, which creates a lambda that exposes the API using
 * PyTorch-compatible types and converts the arguments back to the original
 * types before calling the wrapped function. This is useful when trying to
 * integrate PyBind-based external libraries into vLLM.
 *
 * Example:
 *
 *   PYBIND11_MODULE(NAME, m) {
 *     m.def("foo", &foo);
 *   }
 *
 * could be replaced with (using the shim):
 *   TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 *     m.def("foo", make_pytorch_shim(&foo));
 *     m.impl("foo", torch::kCUDA, make_pytorch_shim(&foo));
 *   }
 *
 * The `pytorch_library_compatible_type` struct is used to map from the
 * flash_attn ops types to PyTorch-library-compatible ones. The main issue is
 * that the following types are not supported by PyTorch library bindings:
 *  - `int`
 *  - `float`
 *  - `c10::optional<T> &`
 *  - `c10::optional<const at::Tensor> &`
 * So we convert them to (respectively):
 *  - `int64_t`
 *  - `double`
 *  - `const c10::optional<T>&`
 *  - `const c10::optional<const at::Tensor>&`
 */

// Identity mapping: by default a type is already PyTorch-library compatible
// and needs no conversion.
template <typename T>
struct pytorch_library_compatible_type {
  using type = T;
  static T convert_from_type(T arg) { return arg; }
};

template <typename T>
using pytorch_library_compatible_type_t =
    typename pytorch_library_compatible_type<T>::type;

template <typename T>
T convert_from_pytorch_compatible_type(
    pytorch_library_compatible_type_t<T> arg) {
  return pytorch_library_compatible_type<T>::convert_from_type(arg);
}

// Map `c10::optional<T> &` -> `const c10::optional<T>&`
//  (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
//   the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
  using type = const c10::optional<T>&;
  static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
    return const_cast<c10::optional<T>&>(arg);
  }
};

// Map `c10::optional<T>` ->
//   `c10::optional<pytorch_library_compatible_type_t<T>>`
//  (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
  using type = c10::optional<pytorch_library_compatible_type_t<T>>;
  static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(
      c10::optional<T> arg) {
    return arg;
  }
};

// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
  using type = const c10::optional<at::Tensor>&;
  static c10::optional<const at::Tensor>& convert_from_type(
      const c10::optional<at::Tensor>& arg) {
    return const_cast<c10::optional<const at::Tensor>&>(
        reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
  }
};

// Map `int` -> `int64_t`, checking that the value fits in an `int`.
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
                "int64_t value is too large to be converted to int");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
                "int64_t value is too small to be converted to int");
    return arg;
  }
};

// Map `float` -> `double`, checking that the value fits in a `float`.
template <>
struct pytorch_library_compatible_type<float> {
  using type = double;
  static float convert_from_type(double arg) {
    TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
                "double value is too large to be converted to float");
    return arg;
  }
};

//
//  Shim Utils
//

// Wrap a function pointer in a lambda whose signature uses only
// PyTorch-library-compatible types; each argument is converted back to the
// original type before the wrapped function is invoked.
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
  return [fun](pytorch_library_compatible_type_t<Args>... args) {
    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
  };
}
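
// A more concrete usage sketch, shown as a comment in the same style as the
// doc-comment example above. The function `paged_attention_fwd` and its
// parameters are hypothetical (not part of this header or of flash_attn);
// the point is to illustrate which conversions the shim applies: the
// registered lambda's signature becomes
//   (const at::Tensor&, const c10::optional<at::Tensor>&, int64_t, double),
// which the PyTorch Library API accepts, while the original PyBind-style
// function keeps its `c10::optional<at::Tensor>&`, `int`, and `float`
// parameters.
//
//   at::Tensor paged_attention_fwd(const at::Tensor& q,
//                                  c10::optional<at::Tensor>& alibi_slopes,
//                                  int num_splits, float scale);
//
//   TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
//     // Schema is inferred from the shimmed lambda, so `int`/`float`
//     // arguments surface as `int` (int64_t) and `float` (double) in the
//     // registered op; range checks happen inside the shim at call time.
//     m.def("paged_attention_fwd", make_pytorch_shim(&paged_attention_fwd));
//     m.impl("paged_attention_fwd", torch::kCUDA,
//            make_pytorch_shim(&paged_attention_fwd));
//   }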