#include #include #include #include namespace c10 { inline KernelFunction::KernelFunction() : boxed_kernel_func_() , unboxed_kernel_func_(nullptr) , sym_unboxed_kernel_func_(nullptr) {} inline KernelFunction::KernelFunction(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) : boxed_kernel_func_(std::move(functor), boxed_kernel_func) , unboxed_kernel_func_(unboxed_kernel_func) , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr) : boxed_kernel_func_(std::move(boxed_fn)) , unboxed_kernel_func_(unboxed_kernel_func) , sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} inline bool KernelFunction::isValidUnboxed() const { return unboxed_kernel_func_ != nullptr; } inline bool KernelFunction::isValidSymUnboxed() const { return sym_unboxed_kernel_func_ != nullptr; } inline bool KernelFunction::isValid() const { return boxed_kernel_func_.isValid(); } inline bool KernelFunction::isFallthrough() const { return boxed_kernel_func_.isFallthrough(); } inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); } template inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) { using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...); ActualSignature* func = reinterpret_cast(unboxed_kernel_func); return (*func)(functor, dispatchKeySet, std::forward(args)...); } // This template requires you to explicitly specify the argument you want to // forward; it doesn't work if you try to deduce it // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp template inline typename remove_symint::type unpackSymInt(T x) { return x; } template <> inline typename remove_symint::type unpackSymInt(c10::SymInt x) { return x.expect_int(); } template <> inline typename remove_symint::type unpackSymInt(c10::SymIntArrayRef x) { return c10::asIntArrayRefSlow(x); } template <> inline typename remove_symint>::type unpackSymInt(c10::optional x) { return x.has_value() ? c10::make_optional(x->expect_int()) : c10::nullopt; } template <> inline typename remove_symint::type unpackSymInt(at::OptionalSymIntArrayRef x) { return x.has_value() ? c10::make_optional(c10::asIntArrayRefSlow(*x)) : c10::nullopt; } template C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const { // note: Args above is intentionally not Args&&. We don't want perfect // forwarding, which would require Args to be deduced, but instead we // want callers to explicitly specify the Args. // This should get inlined by compiler if (guts::disjunction...>::value) { if (sym_unboxed_kernel_func_ != nullptr) { auto *functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction( sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward(args)...); } if (unboxed_kernel_func_ != nullptr) { auto *functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction::type...>( unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt(args)...); } } else { if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { auto *functor = boxed_kernel_func_.getFunctor(); return callUnboxedKernelFunction( unboxed_kernel_func_, functor, dispatchKeySet, std::forward(args)...); } } return impl::BoxedKernelWrapper::call( boxed_kernel_func_, opHandle, dispatchKeySet, std::forward(args)... ); } inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) { return KernelFunction(std::move(boxed_fn), nullptr); // no unboxed function pointer } template inline KernelFunction KernelFunction::makeFromBoxedFunction() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunction()); } template inline KernelFunction KernelFunction::makeFromBoxedFunction() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunction()); } inline KernelFunction KernelFunction::makeFallthrough() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFallthrough()); } inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeAmbiguousAutogradOther()); } inline KernelFunction KernelFunction::makeNamedNotSupported() { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeNamedNotSupported()); } template inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr kernelFunctor) { #ifndef NDEBUG // This assertion is costly for build time so it's debug-gated. static_assert(guts::is_functor::value, "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor."); #endif static_assert(std::is_base_of::value, "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed::call; void* void_unboxed_fn = reinterpret_cast(unboxed_fn); bool is_symint = fn_has_symint::value; return KernelFunction( std::move(kernelFunctor), &impl::make_boxed_from_unboxed_functor::call, is_symint ? nullptr : void_unboxed_fn, is_symint ? void_unboxed_fn : nullptr ); } template inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr kernelFunctor) { return KernelFunction::makeFromBoxedKernel( BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); } template inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) { static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); #if !defined(C10_MOBILE) (void)func_ptr; // Suppress unused variable warning return makeFromUnboxedFunctor::type>( guts::make_unique_base::type>() ); #else // On mobile, we rather want to optimize for binary size than for performance, // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction // instead. return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); #endif } template inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) { static_assert(guts::is_function_type::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type."); static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); return makeFromUnboxedFunctor>>( guts::make_unique_base>>(func) ); } template inline std::enable_if_t>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); #if !defined(C10_MOBILE) return makeFromUnboxedFunctor>>( guts::make_unique_base>>(std::forward(lambda)) ); #else // On mobile, we rather want to optimize for binary size than for performance, // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction // instead. using FuncType = typename guts::infer_function_traits_t>::func_type; return makeFromUnboxedRuntimeFunction(lambda); #endif } template inline std::enable_if_t>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); return makeFromUnboxedFunctor>>( guts::make_unique_base>>(std::forward(lambda)) ); } }