#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


class CyclingIterator:
    """
    An iterator decorator that cycles through the underlying iterator "n"
    times. Useful to "unroll" the dataset across multiple training epochs.

    The generator function is called as ``generator_fn(epoch)`` to obtain the
    underlying iterable for each cycle, where ``epoch`` is the zero-based
    cycle index: it starts at ``start_epoch`` and goes up to ``n - 1``.
    ``generator_fn`` may return any iterable (e.g. a list or a generator);
    ``iter()`` is applied to its result.

    For example, if ``generator_fn`` always returns ``[1, 2, 3]`` then
    ``CyclingIterator(n=2, generator_fn=generator_fn)`` will iterate through
    ``[1, 2, 3, 1, 2, 3]``.

    Note: ``generator_fn(start_epoch)`` is always called once on construction
    and its items are always yielded, even when ``start_epoch >= n - 1``; no
    further epochs are started after that.
    """

    def __init__(self, n: int, generator_fn, start_epoch: int = 0):
        """
        Args:
            n: total number of cycles; epochs run from ``start_epoch`` to
               ``n - 1`` inclusive.
            generator_fn: callable taking the epoch index and returning an
               iterable of items for that epoch.
            start_epoch: epoch to begin with (e.g. to resume mid-training).
        """
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        self._iter = iter(generator_fn(self._epoch))

    def __iter__(self):
        return self

    def __next__(self):
        # Loop (rather than recurse) over exhausted epochs so that a long run
        # of empty epochs cannot overflow the interpreter's recursion limit.
        while True:
            try:
                return next(self._iter)
            except StopIteration:  # end of data for the current epoch
                if self._epoch >= self._n - 1:
                    # Last epoch exhausted: propagate the same StopIteration.
                    raise
                self._epoch += 1
                self._iter = iter(self._generator_fn(self._epoch))