import inspect
import functools
from enum import Enum

import torch.autograd


class _SnapshotState(Enum):
    r"""
    These are the snapshotting-related states that IterDataPipes can be in.
    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
    `Iterating` - can restore, will reset if you create a new iterator
    """
    NotStarted = 0
    Restored = 1
    Iterating = 2


def _simplify_obj_name(obj) -> str:
    """
    Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.
    """
    if inspect.isfunction(obj):
        return obj.__name__
    else:
        return repr(obj)


def _generate_input_args_string(obj):
    """
    Generate a string for the input arguments of an object.
    """
    signature = inspect.signature(obj.__class__)
    input_param_names = set(signature.parameters.keys())
    result = []
    for name, member in inspect.getmembers(obj):
        if name in input_param_names:
            result.append((name, _simplify_obj_name(member)))
    return ', '.join([f'{name}={value}' for name, value in result])


def _generate_iterdatapipe_msg(datapipe):
    return f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"


def _gen_invalid_iterdatapipe_msg(datapipe):
    return ("This iterator has been invalidated because another iterator has been created "
            f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
            "This may be caused by multiple references to the same IterDataPipe. We recommend "
            "using `.fork()` if that is necessary.")


_feedback_msg = ("\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
                 "to comment on this issue: https://github.com/pytorch/data/issues/45.")


def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
    r"""
    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raise an exception.
    For a ChildDataPipe, the ID is also compared against the one stored in `main_datapipe`.
    """
    if next_method_exists:
        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
        # iterator (`0`). Otherwise, it means there are multiple iterators.
        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's `__next__` method"
            raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg)
    elif hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_check_valid_iterator_id"):
            if not datapipe._check_valid_iterator_id(iterator_id):
                raise RuntimeError("This iterator has been invalidated, because a new iterator has been created "
                                   f"from one of the ChildDataPipes of "
                                   f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + _feedback_msg)
        else:
            raise RuntimeError("ChildDataPipe must have method `_check_valid_iterator_id`.")
    elif datapipe._valid_iterator_id != iterator_id:
        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
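
# Hedged sketch (illustration only, not part of this module): the single-iterator
# constraint enforced by `_check_iterator_valid` means that creating a second
# iterator over the same IterDataPipe invalidates the first one. `IterableWrapper`
# is assumed here purely as an example DataPipe:
#
#     dp = IterableWrapper(range(10))
#     it1 = iter(dp)
#     next(it1)       # fine, `it1` holds the currently valid iterator ID
#     it2 = iter(dp)  # bumps `_valid_iterator_id`, invalidating `it1`
#     next(it1)       # raises RuntimeError via `_check_iterator_valid`
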
""" if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True: if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"): datapipe._set_main_datapipe_valid_iterator_id() # reset() is called within this method when appropriate else: raise RuntimeError("ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`.") else: if datapipe._valid_iterator_id is None: datapipe._valid_iterator_id = 0 else: datapipe._valid_iterator_id += 1 datapipe.reset() return datapipe._valid_iterator_id def hook_iterator(namespace, profile_name): r""" Hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. This is done for the purpose of profiling and checking if an iterator is still valid. """ def profiler_record_fn_context(): return torch.autograd.profiler.record_function(profile_name) class IteratorDecorator: r""" Wrap the iterator and modifying its `__next__` method. This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function. Those `__iter__` method commonly returns `self` but not necessarily. """ def __init__(self, iterator, source_dp, iterator_id, has_next_method): self.iterator = iterator self.source_dp = source_dp self.iterator_id = iterator_id self._profiler_enabled = torch.autograd._profiler_enabled() # Check if `__iter__` returns `self` and `DataPipe` has `__next__` self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method def __iter__(self): return self def _get_next(self): r""" Return next with logic related to iterator validity, profiler, and incrementation of samples yielded. """ _check_iterator_valid(self.source_dp, self.iterator_id) result = next(self.iterator) if not self.self_and_has_next_method: self.source_dp._number_of_samples_yielded += 1 return result def __next__(self): # TODO: Add try-except to in-place reduce traceback from the Exception # See: https://github.com/pytorch/data/issues/284 if self._profiler_enabled: with profiler_record_fn_context(): return self._get_next() else: # Decided against using `contextlib.nullcontext` for performance reasons return self._get_next() def __getattr__(self, name): return getattr(self.iterator, name) func = namespace['__iter__'] # ``__iter__`` of IterDataPipe is a generator function if inspect.isgeneratorfunction(func): @functools.wraps(func) def wrap_generator(*args, **kwargs): gen = func(*args, **kwargs) datapipe = args[0] if datapipe._fast_forward_iterator: it = datapipe._fast_forward_iterator datapipe._fast_forward_iterator = None datapipe._snapshot_state = _SnapshotState.Iterating while True: try: yield next(it) except StopIteration: return iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator _profiler_enabled = torch.autograd._profiler_enabled() try: if _profiler_enabled: with profiler_record_fn_context(): response = gen.send(None) else: response = gen.send(None) while True: datapipe._number_of_samples_yielded += 1 request = yield response # Pass through here every time `__next__` is called if _profiler_enabled: with profiler_record_fn_context(): _check_iterator_valid(datapipe, iterator_id) response = gen.send(request) else: # Decided against using `contextlib.nullcontext` for performance reasons _check_iterator_valid(datapipe, iterator_id) response = gen.send(request) except StopIteration as e: return except Exception as e: # TODO: Simplify the traceback message to skip over `response = gen.send(None)` # Part of https://github.com/pytorch/data/issues/284 datapipe = args[0] msg = "thrown by 
__iter__ of" single_iterator_msg = "single iterator per IterDataPipe constraint" if hasattr(e.args, '__len__'): full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})" if len(e.args) == 0 or not isinstance(e.args[0], str): # If an exception message doesn't exist e.args = (f'\nThis exception is {full_msg}',) elif msg not in e.args[0] and single_iterator_msg not in e.args[0]: e.args = (e.args[0] + f'\nThis exception is {full_msg}',) + e.args[1:] raise namespace['__iter__'] = wrap_generator else: # ``__iter__`` of IterDataPipe is NOT a generator function # IterDataPipe is an iterator with both ``__iter__`` and ``__next__`` # And ``__iter__`` may or may not return `self` if '__next__' in namespace: # If `__next__` exists, put a wrapper around it next_func = namespace['__next__'] @functools.wraps(next_func) def wrap_next(*args, **kwargs): if torch.autograd._profiler_enabled(): with profiler_record_fn_context(): result = next_func(*args, **kwargs) else: result = next_func(*args, **kwargs) datapipe = args[0] datapipe._number_of_samples_yielded += 1 return result namespace['__next__'] = wrap_next # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but # the user will be violating the iterator protocol. Potential issue: # 1. Valid iterator ID may not update or checked properly # 2. The number of samples yielded will be miscounted # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators @functools.wraps(func) def wrap_iter(*args, **kwargs): iter_ret = func(*args, **kwargs) datapipe = args[0] datapipe._snapshot_state = _SnapshotState.Iterating if datapipe._fast_forward_iterator: iter_ret = datapipe._fast_forward_iterator datapipe._fast_forward_iterator = None return iter_ret iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace) namespace['__iter__'] = wrap_iter