from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk from typing import List, Optional, Sized, TypeVar __all__ = ["BatcherMapDataPipe", ] T = TypeVar('T') @functional_datapipe('batch') class BatcherMapDataPipe(MapDataPipe[DataChunk]): r""" Create mini-batches of data (functional name: ``batch``). An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. Args: datapipe: Iterable DataPipe being batched batch_size: The size of each batch drop_last: Option to drop the last batch if it's not full Example: >>> # xdoctest: +SKIP >>> from torchdata.datapipes.map import SequenceWrapper >>> dp = SequenceWrapper(range(10)) >>> batch_dp = dp.batch(batch_size=2) >>> list(batch_dp) [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] """ datapipe: MapDataPipe batch_size: int drop_last: bool length: Optional[int] def __init__(self, datapipe: MapDataPipe[T], batch_size: int, drop_last: bool = False, wrapper_class=DataChunk, ) -> None: assert batch_size > 0, "Batch size is required to be larger than 0!" super().__init__() self.datapipe = datapipe self.batch_size = batch_size self.drop_last = drop_last self.length = None self.wrapper_class = wrapper_class def __getitem__(self, index) -> DataChunk: batch: List = [] indices = range(index * self.batch_size, (index + 1) * self.batch_size) try: for i in indices: batch.append(self.datapipe[i]) return self.wrapper_class(batch) except IndexError: if not self.drop_last and len(batch) > 0: return self.wrapper_class(batch) else: raise IndexError(f"Index {index} is out of bound.") def __len__(self) -> int: if self.length is not None: return self.length if isinstance(self.datapipe, Sized): if self.drop_last: self.length = len(self.datapipe) // self.batch_size else: self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size return self.length raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))