#pragma once #include #include #include namespace at { class TensorBase; // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, // so we have a couple of workarounds. // // In the long term, we'll move Dimname to c10 and everything in this file // can be refactored out. The main blocker for that is that "c10::Symbol" // actually exists outside of c10 and needs to be moved in. // TensorImpl has a unique_ptr field. // XXX: Ideally we would just put optional> into TensorImpl. // // This class has an important invariant: there must be at least ONE // non-wildcard struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { // This enum is to remind people that the invariant on constructors is that // the list of dimnames must have at least one non-wildcard enum HAS_NON_WILDCARD { HasNonWildcard }; explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names) : names_(names.vec()) { check_invariants(); } explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector&& names) : names_(std::move(names)) { check_invariants(); } std::unique_ptr clone() const override { return std::make_unique(HasNonWildcard, names_); } DimnameList names() const { return names_; } // Used for an assertion in TensorImpl.h int64_t slow_dim() const override { return names_.size(); } void check_invariants() const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); })); } void set_names(HAS_NON_WILDCARD, DimnameList new_names) { TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); std::copy(new_names.begin(), new_names.end(), names_.begin()); check_invariants(); } void set_names(HAS_NON_WILDCARD, std::vector&& new_names) { TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); names_ = std::move(new_names); check_invariants(); } // INVARIANT: at least one Dimname is non-WILDCARD std::vector names_; }; // When NamesMode is disabled, then all operations ignore tensors' names fields. // Concretely speaking, all tensors are treated as having nullopt names. struct TORCH_API NamesMode { static bool is_enabled(); static void set_enabled(bool enabled); }; // A RAII, thread local (!) guard that enables or disables names upon // construction, and sets it back to the original value upon destruction. struct TORCH_API NoNamesGuard { NoNamesGuard() : prev_mode(NamesMode::is_enabled()), initialized(true) { NamesMode::set_enabled(false); } ~NoNamesGuard() { if (initialized) { reset(); } } void reset() { TORCH_INTERNAL_ASSERT(initialized); NamesMode::set_enabled(prev_mode); } private: bool prev_mode; bool initialized; }; void check_names_valid_for(const TensorBase& tensor, DimnameList names); void check_names_valid_for(size_t tensor_dim, DimnameList names); // Sets the names of `tensor` to be `names`. TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, c10::optional names); TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names); constexpr size_t kMaxNamedTensorDim = 64; DimnameList default_names(size_t len); namespace impl { // Some helper functions on TensorImpl. Useful for working with names in TH. // XXX: Ideally these would exist as methods on TensorImpl TORCH_API void internal_set_names_inplace(TensorImpl* impl, c10::optional names, bool validate_names); TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); void check_names_valid_for(TensorImpl* impl, DimnameList names); // Returns true if the tensor's names exist and are not all 'None'. // Returns false if the tensor's names don't exist (were not allocated), // or if all names are 'None'. // We treat not-allocated-names the same as allocated names that are all 'None'. TORCH_API bool has_names(const TensorImpl* impl); // Returns the names of the tensor's dimensions. // Unnamed tensors are treated as having 'None' in all dimension; this method // would return a DimnameList of all 'None's for an unnamed tensor. TORCH_API DimnameList get_names(const TensorImpl* impl); // This is more of an implementation detail; one should use impl::get_names / // Tensor::names() whenever possible because it provides a cleaner API. // Returns the names of the tensor if they have been allocated; returns nullopt // instead if the haven't been. The names of a tensor are not allocated if a // tensor is constructed with names=None. TORCH_API c10::optional get_opt_names(const TensorImpl* impl); } // namespace impl } // namespace at