#pragma once /** * This file contains functionality to take a C++ function and infer its * c10::FunctionSchema. */ #include #include #include namespace c10 { namespace detail { namespace infer_schema { /// The templated inference code creates `ArgumentDef` instead of `Argument`, /// because that can be constructed at compile time and has a much smaller /// binary size than having calls to `Argument` constructors in the template. /// Creating `Argument` objects from `ArgumentDef` can then be done at /// runtime in a non-templated way. struct ArgumentDef final { using GetTypeFn = TypePtr(); GetTypeFn* getTypeFn; GetTypeFn* getFakeTypeFn; constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} }; template struct bool_t {}; template<> struct bool_t : std::true_type {}; template<> struct bool_t : std::false_type {}; /// Checks the static C++ types `Types` for correctness to catch common error cases. template constexpr int checkStaticTypes() { // Give nice error messages for some of the common error cases. // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT static_assert(guts::conjunction< bool_t::value || std::is_same::value || std::is_same::value>... >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type"); static_assert(guts::conjunction< bool_t::value>... >::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); return 0; } template constexpr std::array createArgumentVectorFromTypes(std::index_sequence) { return ( // Check types for common errors checkStaticTypes(), // Create the return value std::array{ ArgumentDef(&getTypePtrCopy>, &getFakeTypePtrCopy>)...} ); } /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified /// as template arguments. template struct createArguments final {}; template struct createArguments> final { static constexpr std::array call() { return createArgumentVectorFromTypes( std::make_index_sequence() ); } }; /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified /// as a tuple (i.e. in the way c10 kernels return values). /// It can be a tuple if there's three output arguments with types A, B, C. /// It can be an empty tuple<>, or void for kernels that don't return anything. /// It can be a single type A (i.e. no tuple) for the case where a kernel just /// returns one value. template struct createReturns final {}; template struct createReturns, void> final { static constexpr std::array call() { return createArgumentVectorFromTypes( std::make_index_sequence() ); } }; template struct createReturns::value && !guts::is_instantiation_of::value>> final { static constexpr std::array call() { return createReturns>::call(); } }; template<> struct createReturns final { static constexpr std::array call() { return createReturns>::call(); } }; template struct createSingleReturn { static constexpr std::array call() { return createArgumentVectorFromTypes(std::make_index_sequence<1>()); } }; C10_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef arguments, c10::ArrayRef returns); C10_API FunctionSchema make_function_schema(c10::ArrayRef arguments, c10::ArrayRef returns); /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a /// function. Flattens std::tuple returns into multiple return types template FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() { using ReturnType = typename FunctionTraits::return_type; using ParameterTypes = typename FunctionTraits::parameter_types; // arguments and returns are computed into a std::array at compile time and embedded into the binary. // The only code executed at runtime here is the one that creates a std::vector // of the arguments/returns from the std::array. constexpr auto arguments = createArguments::call(); constexpr auto returns = createReturns::call(); return make_function_schema(arguments, returns); } /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a /// function. Preserves std::tuple returns as a Tuple return type template FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) { using ReturnType = typename FunctionTraits::return_type; using ParameterTypes = typename FunctionTraits::parameter_types; // arguments and returns are computed into a std::array at compile time and embedded into the binary. // The only code executed at runtime here is the one that creates a std::vector // of the arguments/returns from the std::array. constexpr auto arguments = createArguments::call(); constexpr auto returns = createSingleReturn::call(); return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); } } } template FunctionSchema inferFunctionSchemaFlattenedReturns() { return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns>(); } template FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) { return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn>(std::move(name), std::move(overload_name)); } TORCH_API c10::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); }