#pragma once #include #include namespace at { class Tensor; class OptionalTensorRef; } namespace c10 { namespace detail { /* * Specializations of `IListRefTagImplBase` that implement the default * implementation for `IListRefTag::Unboxed`. */ template class IListRefTagImplBase { public: using elem_type = ListElemT; using list_type = ArrayRef; /* * These `unwrap` static methods unwraps the inner containers out * of `IListRef` (and `IListRefIterator`). They are required when * the macro `TORCH_ILISTREF_UNWRAP` is called. */ static const list_type& unwrap(const IListRef& ilist) { return ilist.payload_.unboxed; } static typename list_type::const_iterator& unwrap(IListRefIterator& it) { return it.payload_.unboxed_iterator; } static const typename list_type::const_iterator& unwrap( const IListRefIterator& it) { return it.payload_.unboxed_iterator; } /* * We have these function (besides the `unwrap`s above) because the * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*` * weren't syntatically equal for the existing tags at the time * (`Unboxed` and `Boxed`). */ static IListRefConstRef front(const list_type& lst) { return lst.front(); } static IListRefConstRef iterator_get( const typename list_type::const_iterator& it) { return *it; } }; /* * Specializations of `IListRefTagImplBase` that implement the default * implementation for `IListRefTag::Boxed`. */ template class IListRefTagImplBase { public: using elem_type = ListElemT; using list_type = List; static const list_type& unwrap(const IListRef& ilist) { return *ilist.payload_.boxed; } static typename list_type::const_iterator& unwrap(IListRefIterator& it) { return it.payload_.boxed_iterator; } static const typename list_type::const_iterator& unwrap( const IListRefIterator& it) { return it.payload_.boxed_iterator; } static IListRefConstRef front(const list_type& lst) { return lst[0]; } static IListRefConstRef iterator_get( const typename list_type::const_iterator& it) { return (*it).get().toTensor(); } }; /* * Specializations of `IListRefTagImplBase` that implement the default * implementation for `IListRefTag::Materialized`. */ template class IListRefTagImplBase> { public: using elem_type = MaterializedIListRefElem; using list_type = MaterializedIListRef; static const list_type& unwrap(const IListRef& ilist) { return *ilist.payload_.materialized; } static typename list_type::const_iterator& unwrap(IListRefIterator& it) { return it.payload_.materialized_iterator; } static const typename list_type::const_iterator& unwrap( const IListRefIterator& it) { return it.payload_.materialized_iterator; } static IListRefConstRef front(const list_type& lst) { return lst[0]; } static IListRefConstRef iterator_get( const typename list_type::const_iterator& it) { return *it; } }; /* * [Note: ITensorListRef] * Specializations necessary for `IListRef` type. * * Since the default implementations are usually done with supporting * `Tensor` in mind, we only have to inherit from the base implementations. */ template <> class IListRefTagImpl : public IListRefTagImplBase {}; template <> class IListRefTagImpl : public IListRefTagImplBase {}; template <> class IListRefTagImpl : public IListRefTagImplBase< IListRefTag::Materialized, at::Tensor, MaterializedIListRefElem> {}; /* * [Note: IOptTensorListRef] * Specializations necessary for `IListRef` type. * * We can't get an `at::OptionalTensorRef` directly from an instance of * `List>` (the type that corresponds to the boxed world). * * So, the default implementation won't help us. Thus, we have to implement * this method ourselves. */ template <> class IListRefTagImpl : public IListRefTagImplBase {}; template <> class IListRefTagImpl : public IListRefTagImplBase> { public: /* * Given an instance of the types corresponding to the `Boxed` tag, we override * the default implementation, so that we can return a `at::OptionalTensorRef`. */ static IListRefConstRef iterator_get( const typename list_type::const_iterator& it) { const auto& ivalue = (*it).get(); if (!ivalue.isNone()) { const auto& tensor = ivalue.toTensor(); return (tensor.defined()) ? tensor : at::OptionalTensorRef{}; } return {}; } }; template <> class IListRefTagImpl : public IListRefTagImplBase< IListRefTag::Materialized, at::OptionalTensorRef, MaterializedIListRefElem> {}; } // namespace detail } // namespace c10 namespace at { // [Note: ITensorListRef] using ITensorListRef = c10::IListRef; using ITensorListRefIterator = c10::IListRefIterator; using MaterializedITensorListRef = c10::detail::MaterializedIListRef; // [Note: IOptTensorListRef] using IOptTensorListRef = c10::IListRef; using IOptTensorListRefIterator = c10::IListRefIterator; using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef; } // namespace at