#pragma once #include #include #include #include #include #include namespace c10 { using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. class OperatorHandle; struct OperatorKernel; class KernelFunction; template using has_symint = guts::disjunction< std::is_same>, std::is_same>, std::is_same>, std::is_same, std::decay_t> >; template struct remove_symint { using type = T; }; template <> struct remove_symint { using type = int64_t; }; template <> struct remove_symint { using type = OptionalIntArrayRef; }; template <> struct remove_symint { using type = c10::IntArrayRef; }; template <> struct remove_symint> { using type = c10::optional; }; template struct maybe_keep_symint final {}; template struct maybe_keep_symint { using type = T; }; template struct maybe_keep_symint { using type = typename remove_symint::type; }; template using fn_has_symint = typename guts::typelist::true_for_any_type< has_symint, typename guts::infer_function_traits::type::parameter_types >; /** * KernelFunction is similar to std::function but stores a kernel function. * You can create a KernelFunction from a boxed or unboxed function/functor/lambda * and call it in a boxed or unboxed way. If the way it was created doesn't * match the way it was called, it will do boxing or unboxing as necessary. */ class TORCH_API KernelFunction final { public: using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction; using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction; using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys; KernelFunction(); // Fast path for dispatch to allow not touching the boxed kernel in // the common case where unboxed is available. bool isValidUnboxed() const; bool isValidSymUnboxed() const; bool isValid() const; bool isFallthrough() const; /** * Call the function in a boxed way. * If the kernel function was created with an unboxed function, * this will call an unboxing wrapper which then calls into that * unboxed function. * * Example: * * > void boxed_func(OperatorKernel*, Stack* stack) {...} * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); * > Tensor result = func.callBoxed(stack); * * Or, with an unboxed implementation: * * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( * > [] (Tensor a, bool b) -> Tensor {...}); * > Tensor result = func.callBoxed(stack); */ void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const; /** * Call the function in an unboxed way. * If the kernel function was created with a boxed function, * this will box all inputs and then call into that boxed function. * * Note that this doesn't work for all types yet. * * Example: * * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( * > [] (Tensor a, bool b) -> Tensor {...}); * > Tensor result = func.call(tensor1, true); * * Or, with a boxed implementation: * * > void boxed_func(OperatorKernel*, Stack* stack) {...} * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); * > Tensor result = func.call(tensor1, true); */ template Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const; /** * Create a KernelFunction from a BoxedKernel. */ static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn); /** * Create a KernelFunction from a boxed function. * * Example: * * > void boxed_func(OperatorKernel*, Stack* stack) {...} * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>(); */ template static KernelFunction makeFromBoxedFunction(); /** * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none) * See Note [Plumbing Keys Through The Dispatcher] for details. */ template static KernelFunction makeFromBoxedFunction(); /** * Create a KernelFunction from an unboxed functor. * * Example: * * > class MyFunctor final : public c10::OperatorKernel { * > public: * > Tensor operator()(Tensor a, Tensor b) {...} * > }; * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor(std::make_unique()); */ template static KernelFunction makeFromUnboxedFunctor(std::unique_ptr kernelFunctor); /** * Create a KernelFunction from a boxed functor. * * Example: * * > class MyFunctor final : public c10::OperatorKernel { * > public: * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} * > }; * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique()); */ template static KernelFunction makeFromBoxedFunctor(std::unique_ptr kernelFunctor); /** * Create a KernelFunction from an unboxed function. * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction * because knowing the function pointer as a template argument (i.e. at * compile time) allows the compiler to inline the function into its * unboxing wrapper and yields better performance when calling the function. * * Example: * * > Tensor unboxed_func(Tensor a, Tensor b) {...} * > KernelFunction func = KernelFunction::makeFromUnboxedFunction(); */ template static KernelFunction makeFromUnboxedFunction(FuncPtr); /** * Create a KernelFunction from an unboxed function. * KernelFunction::makeFromUnboxedFunction is usually a better choice than * this if you know the function pointer at compile time, see doc comment * there for an explanation. * * Example: * * > Tensor unboxed_func(Tensor a, Tensor b) {...} * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func); */ template static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func); static KernelFunction makeFallthrough(); static KernelFunction makeAmbiguousAutogradOther(); static KernelFunction makeNamedNotSupported(); /** * Create a KernelFunction from an unboxed lambda. * * Example: * * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( * > [] (Tensor a, bool b) -> Tensor {...}); */ template static std::enable_if_t>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); template static std::enable_if_t>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); std::string dumpState() const; // For testing internal invariants only bool _equalsBoxedAndUnboxed(const KernelFunction&) const; private: explicit KernelFunction( std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func); explicit KernelFunction( BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func); BoxedKernel boxed_kernel_func_; void* unboxed_kernel_func_; void* sym_unboxed_kernel_func_; }; } #include