#pragma once #include #include namespace c10 { template decltype(auto) getTypePtr(); std::string toString(const Type& type); template List::List(c10::intrusive_ptr&& elements) : impl_(std::move(elements)) {} template List::List(const c10::intrusive_ptr& elements) : impl_(elements) {} template List::List() : List(make_intrusive( typename c10::detail::ListImpl::list_type(), getTypePtr())) { static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead."); } template List::List(ArrayRef values) : List(make_intrusive( typename c10::detail::ListImpl::list_type(), getTypePtr())) { static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); impl_->list.reserve(values.size()); for (const T& element : values) { impl_->list.push_back(element); } } template List::List(std::initializer_list initial_values) : List(ArrayRef(initial_values)) { static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); } template List::List(TypePtr elementType) : List(make_intrusive( typename c10::detail::ListImpl::list_type(), std::move(elementType))) { static_assert(std::is_same::value || std::is_same>::value, "This constructor is only valid for c10::impl::GenericList or List."); } namespace impl { template List toTypedList(impl::GenericList list) { // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant // because upcasting would allow people to add types into the new list that would break the old list. // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can // allow upcasting. This can be a perf improvement since we can cast List to List> // without having to copy it. This is also used to provide backwards compatibility with some old models // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_ // as List before we changed that argument to be List>. When deserializing, we // have list.use_count() == 1 and can deserialize the List directly as List>. TORCH_CHECK(*list.impl_->elementType == *getTypePtr() || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr())) , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr()), ">. Types mismatch."); return List(std::move(list.impl_)); } template impl::GenericList toList(List&& list) { return GenericList(std::move(list.impl_)); } template impl::GenericList toList(const List& list) { return GenericList(list.impl_); } } template List List::copy() const { return List(impl_->copy()); } namespace detail { template T list_element_to(T element) { return element; } template T list_element_to(const IValue& element) { return element.template to(); } template T list_element_to(IValue&& element) { return std::move(element).template to(); } template struct ListElementFrom { static IValue from(const T& element) { return element; } static IValue from(T&& element) { return std::move(element); } }; template<> struct ListElementFrom { static const IValue& from(const IValue& element) { return element; } static IValue&& from(IValue&& element) { return std::move(element); } }; } namespace impl { template ListElementReference::operator std::conditional_t< std::is_reference::type>::value, const T&, T>() const { return iterator_->template to(); } template ListElementReference& ListElementReference::operator=(T&& new_value) && { *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } template ListElementReference& ListElementReference::operator=(const T& new_value) && { *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } template ListElementReference& ListElementReference::operator=(ListElementReference&& rhs) && { *iterator_ = *rhs.iterator_; return *this; } template void swap(ListElementReference&& lhs, ListElementReference&& rhs) { std::swap(*lhs.iterator_, *rhs.iterator_); } template bool operator==(const ListElementReference& lhs, const T& rhs) { T lhs_tmp = lhs; return lhs_tmp == rhs; } template inline bool operator==(const T& lhs, const ListElementReference& rhs) { return rhs == lhs; } template inline typename ListElementConstReferenceTraits::const_reference list_element_to_const_ref(const IValue& element) { return element.template to(); } template<> inline typename ListElementConstReferenceTraits>::const_reference list_element_to_const_ref>(const IValue& element) { return element.toOptionalStringRef(); } } // namespace impl template void List::set(size_type pos, const value_type& value) const { impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); } template void List::set(size_type pos, value_type&& value) const { impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); } template typename List::value_type List::get(size_type pos) const { return c10::detail::list_element_to(impl_->list.at(pos)); } template typename List::internal_const_reference_type List::operator[](size_type pos) const { return c10::impl::list_element_to_const_ref(impl_->list.at(pos)); } template typename List::internal_reference_type List::operator[](size_type pos) { static_cast(impl_->list.at(pos)); // Throw the exception if it is out of range. return {impl_->list.begin() + pos}; } template typename List::value_type List::extract(size_type pos) const { auto& elem = impl_->list.at(pos); auto result = c10::detail::list_element_to(std::move(elem)); // Reset the list element to a T() instead of None to keep it correctly typed elem = c10::detail::ListElementFrom::from(T{}); return result; } template typename List::iterator List::begin() const { return iterator(impl_->list.begin()); } template typename List::iterator List::end() const { return iterator(impl_->list.end()); } template bool List::empty() const { return impl_->list.empty(); } template typename List::size_type List::size() const { return impl_->list.size(); } template void List::reserve(size_type new_cap) const { impl_->list.reserve(new_cap); } template void List::clear() const { impl_->list.clear(); } template typename List::iterator List::insert(iterator pos, const T& value) const { return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; } template typename List::iterator List::insert(iterator pos, T&& value) const { return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; } template template typename List::iterator List::emplace(iterator pos, Args&&... value) const { // TODO Use list_element_from? return iterator { impl_->list.emplace(pos.iterator_, std::forward(value)...) }; } template void List::push_back(const T& value) const { impl_->list.push_back(c10::detail::ListElementFrom::from(value)); } template void List::push_back(T&& value) const { impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); } template void List::append(List b) const { if (b.use_count() == 1) { impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end())); } else { impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end()); } } template template void List::emplace_back(Args&&... args) const { // TODO Use list_element_from? impl_->list.push_back(T(std::forward(args)...)); } template typename List::iterator List::erase(iterator pos) const { return iterator { impl_->list.erase(pos.iterator_) }; } template typename List::iterator List::erase(iterator first, iterator last) const { return iterator { impl_->list.erase(first.iterator_, last.iterator_) }; } template void List::pop_back() const { impl_->list.pop_back(); } template void List::resize(size_type count) const { impl_->list.resize(count, T{}); } template void List::resize(size_type count, const T& value) const { impl_->list.resize(count, value); } template bool operator==(const List& lhs, const List& rhs) { // Lists with the same identity trivially compare equal. if (lhs.impl_ == rhs.impl_) { return true; } // Otherwise, just compare values directly. return *lhs.impl_ == *rhs.impl_; } template bool operator!=(const List& lhs, const List& rhs) { return !(lhs == rhs); } template bool List::is(const List& rhs) const { return this->impl_ == rhs.impl_; } template std::vector List::vec() const { std::vector result(begin(), end()); return result; } template size_t List::use_count() const { return impl_.use_count(); } template TypePtr List::elementType() const { return impl_->elementType; } template void List::unsafeSetElementType(TypePtr t) { impl_->elementType = std::move(t); } }