#include #include #include #include #include #include namespace c10 { namespace detail { template std::enable_if_t< !std::is_array_v && !std::is_array_v && std::is_base_of_v, std::unique_ptr> make_unique_base(Args&&... args) { return std::unique_ptr(new Child(std::forward(args)...)); } } // namespace detail inline KernelFunction::KernelFunction() : boxed_kernel_func_(), unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {} inline KernelFunction::KernelFunction( std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) : boxed_kernel_func_(std::move(functor), boxed_kernel_func), unboxed_kernel_func_(unboxed_kernel_func), sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} inline KernelFunction::KernelFunction( BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) : boxed_kernel_func_(std::move(boxed_fn)), unboxed_kernel_func_(unboxed_kernel_func), sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} inline bool KernelFunction::isValidUnboxed() const { return unboxed_kernel_func_ != nullptr; } inline bool KernelFunction::isValidSymUnboxed() const { return sym_unboxed_kernel_func_ != nullptr; } inline bool KernelFunction::isValid() const { return boxed_kernel_func_.isValid(); } inline bool KernelFunction::isFallthrough() const { return boxed_kernel_func_.isFallthrough(); } inline void KernelFunction::callBoxed( const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); } template inline Return callUnboxedKernelFunction( void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) { using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...); ActualSignature* func = reinterpret_cast(unboxed_kernel_func); return (*func)(functor, dispatchKeySet, std::forward(args)...); } // This template requires you to explicitly specify the argument you want to // forward; it doesn't work if you try to deduce it // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp template inline typename remove_symint::type unpackSymInt(T x) { return x; } template <> inline typename remove_symint::type unpackSymInt(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); } template <> inline typename remove_symint::type unpackSymInt( c10::SymIntArrayRef x) { return C10_AS_INTARRAYREF_SLOW(x); } template <> inline typename remove_symint>::type unpackSymInt( std::optional x) { return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) : std::nullopt; } template <> inline typename remove_symint::type unpackSymInt( at::OptionalSymIntArrayRef x) { return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : std::nullopt; } template C10_ALWAYS_INLINE Return KernelFunction::call( const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const { // note: Args above is intentionally not Args&&. We don't want perfect // forwarding, which would require Args to be deduced, but instead we // want callers to explicitly specify the Args. if constexpr (std::disjunction_v...>) { if (sym_unboxed_kernel_func_ != nullptr) { auto* functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction( sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward(args)...); } if (unboxed_kernel_func_ != nullptr) { auto* functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction< Return, typename remove_symint::type...>( unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt(args)...); } } else { if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { auto* functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction( unboxed_kernel_func_, functor, dispatchKeySet, std::forward(args)...); } } return impl::BoxedKernelWrapper::call( boxed_kernel_func_, opHandle, dispatchKeySet, std::forward(args)...); } inline KernelFunction KernelFunction::makeFromBoxedKernel( BoxedKernel boxed_fn) { return KernelFunction( std::move(boxed_fn), nullptr); // no unboxed function pointer } template inline KernelFunction KernelFunction::makeFromBoxedFunction() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunction()); } template inline KernelFunction KernelFunction::makeFromBoxedFunction() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunction()); } inline KernelFunction KernelFunction::makeFallthrough() { return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough()); } inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeAmbiguousAutogradOther()); } inline KernelFunction KernelFunction::makeNamedNotSupported() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeNamedNotSupported()); } template inline KernelFunction KernelFunction::makeFromUnboxedFunctor( std::unique_ptr kernelFunctor) { #ifndef NDEBUG // This assertion is costly for build time so it's debug-gated. static_assert( guts::is_functor::value, "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor."); #endif static_assert( std::is_base_of_v, "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed::call; void* void_unboxed_fn = reinterpret_cast(unboxed_fn); bool is_symint = fn_has_symint::value; return KernelFunction( std::move(kernelFunctor), &impl::make_boxed_from_unboxed_functor:: call, is_symint ? nullptr : void_unboxed_fn, is_symint ? void_unboxed_fn : nullptr); } template inline KernelFunction KernelFunction::makeFromBoxedFunctor( std::unique_ptr kernelFunctor) { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); } template inline KernelFunction KernelFunction::makeFromUnboxedFunction( FuncPtr func_ptr) { static_assert( is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); static_assert( !std::is_same_v, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); #if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__) TORCH_INTERNAL_ASSERT( FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); #else static_assert( FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); #endif #if !defined(C10_MOBILE) (void)func_ptr; // Suppress unused variable warning return makeFromUnboxedFunctor< AllowLegacyTypes, typename impl::WrapFunctionIntoFunctor::type>( detail::make_unique_base< OperatorKernel, typename impl::WrapFunctionIntoFunctor::type>()); #else // On mobile, we rather want to optimize for binary size than for performance, // so let's not inline the kernel into the wrapper but use // makeFromUnboxedRuntimeFunction instead. return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); #endif } template inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction( FuncType* func) { static_assert( guts::is_function_type::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type."); static_assert( !std::is_same_v, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); return makeFromUnboxedFunctor< AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor>>( detail::make_unique_base< OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor>>(func)); } template inline std::enable_if_t< guts::is_stateless_lambda>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { static_assert( guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); #if !defined(C10_MOBILE) return makeFromUnboxedFunctor< AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor>>( detail::make_unique_base< OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor>>( std::forward(lambda))); #else // On mobile, we rather want to optimize for binary size than for performance, // so let's not inline the kernel into the wrapper but use // makeFromUnboxedRuntimeFunction instead. using FuncType = typename guts::infer_function_traits_t>::func_type; return makeFromUnboxedRuntimeFunction(lambda); #endif } template inline std::enable_if_t< !guts::is_stateless_lambda>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { static_assert( guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); return makeFromUnboxedFunctor< AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor>>( detail::make_unique_base< OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor>>( std::forward(lambda))); } } // namespace c10