#pragma once

#include <cstdint>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Bitset.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <ATen/core/Variadic.h>
#include <ATen/core/stack.h>

namespace c10 {

namespace impl {

// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
// DispatchKey should be, taking into account TLS, and skipping backends which
// fall through.
//
// Unlike Tensor::key_set(), the value of this on a tensor can change depending
// on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
static inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    // The key mask lets us eliminate (by zero entries) keys which should not
    // be considered for dispatch.  There are two cases when we use this:
    //
    // - If an operator's dispatch table contains a fallthrough entry, we
    //   should bypass it entirely when finding the key
    // - If a user invokes with redispatch, the mask lets us
    //   zero out the key the user asked us to stop at.
    //
    // These excluded backends are NOT tracked in the TLS, but must be applied
    // AFTER TLS (since the backend may have been introduced for consideration
    // by the included TLS), which is why you have to pass them in to this
    // function (as opposed to just applying it to the input 'ks').
    DispatchKeySet key_mask
) {
  c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
  // TODO: It's a bit irritating that we have to do logical ORs here, it would
  // be nice to only do one.  Can always_included be folded into the TLS?  Well,
  // it's a bit troublesome, because fastpath TLS access requires the type of
  // the TLS in question to be zero-initialized, so you don't actually win
  // anything in that case.
  return (((ks | local.included_) - local.excluded_) & key_mask);
}

} // namespace impl
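// Illustrative sketch only (not used by this header): how the arithmetic in
// computeDispatchKeySet plays out. The values below are hypothetical and
// chosen purely for the example.
//
//   // A tensor advertising CPU plus its autograd key:
//   DispatchKeySet ks =
//       DispatchKeySet(DispatchKey::CPU) | DispatchKeySet(DispatchKey::AutogradCPU);
//   // If the thread-local state currently excludes the autograd keys (as it
//   // does while autograd is redispatching to the backend kernel), then with
//   // a full key_mask:
//   //   ((ks | included) - excluded) & key_mask == DispatchKeySet(DispatchKey::CPU)
//   // i.e. dispatch skips AutogradCPU and lands on the CPU kernel.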
namespace detail {
  // A small gadget to extract the DispatchKeySet from types which are known
  // to have it.  Used to extract dispatch keys from unboxed calls.
  struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
    DispatchKeySet ts;
    void operator()(const at::Tensor& x) {
      ts = ts | x.key_set();
    }
    void operator()(const c10::optional<at::Tensor>& x) {
      if (x.has_value()) {
        ts = ts | x->key_set();
      }
    }
    void operator()(at::ArrayRef<at::Tensor> xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    // Tensor?[] translates to this case.
    void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
      for (c10::optional<at::Tensor> x : xs) {
        if (x.has_value()) {
          ts = ts | x.value().key_set();
        }
      }
    }
    // Structured Tensor[] translates to this case
    void operator()(at::ITensorListRef xs) {
      for (const auto& x : xs) {
        ts = ts | x.key_set();
      }
    }
    void operator()(at::ArrayRef<c10::optional<at::Tensor>>) {
      // Just checking that the handling of Tensor?[] didn't change.
      TORCH_INTERNAL_ASSERT(false);
    }
    void operator()(const at::Generator& gen) {
      if (gen.defined()) {
        ts = ts | gen.key_set();
      }
    }
    void operator()(const c10::optional<at::Generator>& gen) {
      if (gen.has_value() && gen->defined()) {
        ts = ts | gen->key_set();
      }
    }
    template <typename T>
    void operator()(const T&) {
      // do nothing
    }
  };

  // NB: take by const reference (Don't do universal forwarding here! You
  // don't want to move into this function!)
  template <class... Args>
  DispatchKeySet multi_dispatch_key_set(const Args&... args) {
    return MultiDispatchKeySet().apply(args...).ts;
  }
} // namespace detail

/**
 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
 * a list of arguments for an operator call.
 *
 * The instance is specific to a certain operator as:
 *  - In boxed dispatch, different operators have different ways to extract
 *    the dispatch key (e.g. different numbers of arguments), and we precompute
 *    the stack locations we should look at; and
 *  - In all dispatch, some backends should be excluded from dispatch because
 *    they have been registered as fallthrough.  The set of excluded backends
 *    varies from operator to operator, as some operators may have overridden
 *    the fallthrough with custom behavior.
 *
 *   Note - this should maintain identical impl to the py dispatcher key extraction logic
 *   at pytorch/torch/dispatcher.py
 */
struct TORCH_API DispatchKeyExtractor final {
public:
  static DispatchKeyExtractor make(const FunctionSchema& schema) {
    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
  }

  static DispatchKeyExtractor makeUninitialized() {
    return DispatchKeyExtractor(c10::utils::bitset());
  }

  void registerSchema(const FunctionSchema& schema) {
    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
  }
  void deregisterSchema() {
    dispatch_arg_indices_reverse_ = c10::utils::bitset();
  }

  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
    DispatchKeySet ks;
    dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
      if (C10_LIKELY(ivalue.isTensor())) {
        // NB: Take care not to introduce a refcount bump (there's
        // no safe toTensorRef method, alas)
        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
        for (const at::Tensor& tensor : ivalue.toTensorList()) {
          ks = ks | tensor.key_set();
        }
      }
      // Tensor?[] translates to a c10::List<c10::optional<at::Tensor>>, so we
      // need to peek inside
      else if (C10_UNLIKELY(ivalue.isList())) {
        for (const auto& elt : ivalue.toListRef()) {
          if (elt.isTensor()) {
            ks = ks | elt.toTensor().key_set();
          }
        }
      }
    });
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
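  // Worked example (illustrative comment only; the schema below is hypothetical).
  // For a schema like
  //
  //   my_op(Tensor self, Scalar alpha, Tensor other) -> Tensor
  //
  // makeBitsetForDispatchArgs sets bits for the Tensor-like arguments counted
  // from the *end* of the argument list: `other` is the last argument, so it
  // gets reverse index 0; `alpha` would be reverse index 1 but is not a Tensor,
  // so its bit stays clear; `self` gets reverse index 2. getDispatchKeySetBoxed
  // above then peeks the stack at depths 0 + 1 and 2 + 1 from the top and ORs
  // together the key sets of just those two IValues.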
  template <class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

  std::string dumpState() const;
  void checkInvariants(const FunctionSchema& schema) const;

private:
  static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
    TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
        "The function schema has ", schema.arguments().size(),
        " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
    c10::utils::bitset dispatch_arg_indices_reverse;
    for (const auto index : c10::irange(schema.arguments().size())) {
      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofOptionalTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *OptionalType::ofTensor())) {
        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
      }
    }
    return dispatch_arg_indices_reverse;
  }

  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
  : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
  , nonFallthroughKeys_(DispatchKeySet::FULL)
  , requiresBitsetPerBackend_(false) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // This is a bitset that has ones for each argument index which has to be
  // considered for dispatch. This avoids having to iterate over the stack
  // to find all the tensors. The bits are stored in reverse order, i.e.
  // if dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
  // the top of the stack (i.e. the i-th last argument of the function)
  // is relevant for dispatch.
  //
  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just
  // means you must do the fallthrough.
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of functionality keys for which the operator does NOT have a fallthrough kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have a fallthrough kernel,
  // defined PER BACKEND. This is only needed if we know that the operator has a
  // different set of fallthroughs defined for some backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_.
  bool requiresBitsetPerBackend_;
};

} // namespace c10
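// Usage sketch (illustrative comment only, not part of this header). The
// dispatcher owns one DispatchKeyExtractor per operator; code that had a
// FunctionSchema `schema` and an unboxed Tensor argument `t` in hand might do
// something along these lines:
//
//   auto extractor = c10::DispatchKeyExtractor::make(schema);
//   c10::DispatchKeySet ks = extractor.getDispatchKeySetUnboxed(t);
//   // The highest-priority key in ks is what the dispatcher uses to select a
//   // kernel, after TLS and the per-operator fallthrough mask have been applied.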