#pragma once #include #include #include #include #include #include #include #include #include #include namespace at { class Tensor; } namespace c10 { struct IValue; template class List; struct Type; namespace detail { struct ListImpl final : public c10::intrusive_ptr_target { using list_type = std::vector; explicit ListImpl(list_type list_, TypePtr elementType_) : list(std::move(list_)) , elementType(std::move(elementType_)) {} list_type list; TypePtr elementType; intrusive_ptr copy() const { return make_intrusive(list, elementType); } friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs); }; } namespace impl { template class ListIterator; template class ListElementReference; template void swap(ListElementReference&& lhs, ListElementReference&& rhs); template bool operator==(const ListElementReference& lhs, const T& rhs); template bool operator==(const T& lhs, const ListElementReference& rhs); template struct ListElementConstReferenceTraits { // In the general case, we use IValue::to(). using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return::type; }; // There is no to() overload for c10::optional. template<> struct ListElementConstReferenceTraits> { using const_reference = c10::optional>; }; template class ListElementReference final { public: operator std::conditional_t< std::is_reference::type>::value, const T&, T>() const; ListElementReference& operator=(T&& new_value) &&; ListElementReference& operator=(const T& new_value) &&; // assigning another ref to this assigns the underlying value ListElementReference& operator=(ListElementReference&& rhs) &&; const IValue& get() const& { return *iterator_; } friend void swap(ListElementReference&& lhs, ListElementReference&& rhs); private: ListElementReference(Iterator iter) : iterator_(iter) {} ListElementReference(const ListElementReference&) = delete; ListElementReference& operator=(const ListElementReference&) = delete; // allow moving, but only our friends (i.e. the List class) can move us ListElementReference(ListElementReference&&) noexcept = default; ListElementReference& operator=(ListElementReference&& rhs) & noexcept { iterator_ = std::move(rhs.iterator_); return *this; } friend class List; friend class ListIterator; Iterator iterator_; }; // this wraps vector::iterator to make sure user code can't rely // on it being the type of the underlying vector. template class ListIterator final : public std::iterator< std::random_access_iterator_tag, T, std::ptrdiff_t, T*, ListElementReference> { public: explicit ListIterator() = default; ~ListIterator() = default; ListIterator(const ListIterator&) = default; ListIterator(ListIterator&&) noexcept = default; ListIterator& operator=(const ListIterator&) = default; ListIterator& operator=(ListIterator&&) = default; ListIterator& operator++() { ++iterator_; return *this; } ListIterator operator++(int) { ListIterator copy(*this); ++*this; return copy; } ListIterator& operator--() { --iterator_; return *this; } ListIterator operator--(int) { ListIterator copy(*this); --*this; return copy; } ListIterator& operator+=(typename List::size_type offset) { iterator_ += offset; return *this; } ListIterator& operator-=(typename List::size_type offset) { iterator_ -= offset; return *this; } ListIterator operator+(typename List::size_type offset) const { return ListIterator{iterator_ + offset}; } ListIterator operator-(typename List::size_type offset) const { return ListIterator{iterator_ - offset}; } friend typename std::iterator::difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ - rhs.iterator_; } ListElementReference operator*() const { return {iterator_}; } ListElementReference operator[](typename List::size_type offset) const { return {iterator_ + offset}; } private: explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {} Iterator iterator_; friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ == rhs.iterator_; } friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) { return !(lhs == rhs); } friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ < rhs.iterator_; } friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ <= rhs.iterator_; } friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ > rhs.iterator_; } friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ >= rhs.iterator_; } friend class ListIterator; friend class List; }; template List toTypedList(List list); template List toList(List&& list); template List toList(const List& list); const IValue* ptr_to_first_element(const List& list); } /** * An object of this class stores a list of values of type T. * * This is a pointer type. After a copy, both Lists * will share the same storage: * * > List a; * > List b = a; * > b.push_back("three"); * > ASSERT("three" == a.get(0)); * * We use this class in the PyTorch kernel API instead of * std::vector, because that allows us to do optimizations * and switch out the underlying list implementation without * breaking backwards compatibility for the kernel API. */ template class List final { private: // This is an intrusive_ptr because List is a pointer type. // Invariant: This will never be a nullptr, there will always be a valid // ListImpl. c10::intrusive_ptr impl_; using internal_reference_type = impl::ListElementReference; using internal_const_reference_type = typename impl::ListElementConstReferenceTraits::const_reference; public: using value_type = T; using size_type = typename c10::detail::ListImpl::list_type::size_type; using iterator = impl::ListIterator; using const_iterator = impl::ListIterator; using reverse_iterator = impl::ListIterator; /** * Constructs an empty list. */ explicit List(); /** * Constructs a list with some initial values. * Example: * List a({2, 3, 4}); */ List(std::initializer_list initial_values); explicit List(ArrayRef initial_values); /** * Create a generic list with runtime type information. * This only works for c10::impl::GenericList and is not part of the public API * but only supposed to be used internally by PyTorch. */ explicit List(TypePtr elementType); List(const List&) = default; List& operator=(const List&) = default; /** * Create a new List pointing to a deep copy of the same data. * The List returned is a new list with separate storage. * Changes in it are not reflected in the original list or vice versa. */ List copy() const; /** * Returns the element at specified location pos, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. */ value_type get(size_type pos) const; /** * Moves out the element at the specified location pos and returns it, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. * The list contains an invalid element at position pos afterwards. Any operations * on it before re-setting it are invalid. */ value_type extract(size_type pos) const; /** * Returns a reference to the element at specified location pos, with bounds checking. * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. * * You cannot store the reference, but you can read it and assign new values to it: * * List list = ...; * list[2] = 5; * int64_t v = list[1]; */ internal_const_reference_type operator[](size_type pos) const; internal_reference_type operator[](size_type pos); /** * Assigns a new value to the element at location pos. */ void set(size_type pos, const value_type& value) const; /** * Assigns a new value to the element at location pos. */ void set(size_type pos, value_type&& value) const; /** * Returns an iterator to the first element of the container. * If the container is empty, the returned iterator will be equal to end(). */ iterator begin() const; /** * Returns an iterator to the element following the last element of the container. * This element acts as a placeholder; attempting to access it results in undefined behavior. */ iterator end() const; /** * Checks if the container has no elements. */ bool empty() const; /** * Returns the number of elements in the container */ size_type size() const; /** * Increase the capacity of the vector to a value that's greater or equal to new_cap. */ void reserve(size_type new_cap) const; /** * Erases all elements from the container. After this call, size() returns zero. * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated. */ void clear() const; /** * Inserts value before pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator insert(iterator pos, const T& value) const; /** * Inserts value before pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator insert(iterator pos, T&& value) const; /** * Inserts a new element into the container directly before pos. * The new element is constructed with the given arguments. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ template iterator emplace(iterator pos, Args&&... value) const; /** * Appends the given element value to the end of the container. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void push_back(const T& value) const; /** * Appends the given element value to the end of the container. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void push_back(T&& value) const; /** * Appends the given list to the end of the container. Uses at most one memory allocation. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void append(List lst) const; /** * Appends the given element value to the end of the container. * The new element is constructed with the given arguments. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ template void emplace_back(Args&&... args) const; /** * Removes the element at pos. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator erase(iterator pos) const; /** * Removes the elements in the range [first, last). * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ iterator erase(iterator first, iterator last) const; /** * Removes the last element of the container. * Calling pop_back on an empty container is undefined. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void pop_back() const; /** * Resizes the container to contain count elements. * If the current size is less than count, additional default-inserted elements are appended. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void resize(size_type count) const; /** * Resizes the container to contain count elements. * If the current size is less than count, additional copies of value are appended. * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. */ void resize(size_type count, const T& value) const; /** * Value equality comparison. This function implements Python-like semantics for * equality: two lists with the same identity (e.g. same pointer) trivially * compare equal, otherwise each element is compared for equality. */ template friend bool operator==(const List& lhs, const List& rhs); template friend bool operator!=(const List& lhs, const List& rhs); /** * Identity comparison. Returns true if and only if `rhs` represents the same * List object as `this`. */ bool is(const List& rhs) const; std::vector vec() const; /** * Returns the number of Lists currently pointing to this same list. * If this is the only instance pointing to this list, returns 1. */ // TODO Test use_count size_t use_count() const; TypePtr elementType() const; // See [unsafe set type] for why this exists. void unsafeSetElementType(TypePtr t); private: explicit List(c10::intrusive_ptr&& elements); explicit List(const c10::intrusive_ptr& elements); friend struct IValue; template friend List impl::toTypedList(List); template friend List impl::toList(List&&); template friend List impl::toList(const List&); friend const IValue* impl::ptr_to_first_element(const List& list); }; namespace impl { // GenericList is how IValue stores lists. It is, however, not part of the // public API. Kernels should use Lists with concrete types instead // (maybe except for some internal prim ops). using GenericList = List; inline const IValue* ptr_to_first_element(const GenericList& list) { return &list.impl_->list[0]; } } } namespace torch { template using List = c10::List; } #include // IWYU pragma: keep