import time
import types

from torch.utils.data import IterDataPipe, communication

DEFAULT_NON_BLOCKING_SLEEP = 0.001

__all__ = [
    "DataPipeBehindQueues",
    "EnsureNonBlockingDataPipe",
    "InvalidStateResetRequired",
    "NonBlocking",
    "NotAvailable",
    "QueueWrapper",
    "default_not_available_hook",
]


def default_not_available_hook():
    """Default back-off between retries: sleep for ``DEFAULT_NON_BLOCKING_SLEEP`` seconds."""
    time.sleep(DEFAULT_NON_BLOCKING_SLEEP)


class NotAvailable(Exception):
    """Raised by ``nonblocking_next`` when no item is ready yet; the caller may retry."""
    pass


class InvalidStateResetRequired(Exception):
    """
    Raised by a DataPipe when it expects a reset request before it can continue,
    for example a RouterDataPipe waiting for all workers to request a reset.
    """
    pass


class NonBlocking(IterDataPipe):
    """
    Base class for DataPipes driven through a non-blocking ``nonblocking_next``.

    Subclasses implement ``nonblocking_next`` (return an item, raise
    ``NotAvailable`` when nothing is ready yet, or raise ``StopIteration`` when
    exhausted) and ``reset_iterator``. The regular iterator protocol is built on
    top of them: ``iter()`` resets, and ``next()`` retries ``nonblocking_next``,
    calling the class-wide ``not_available_hook`` between attempts.
    """

    # Called with no arguments between retries. It is always looked up on the
    # NonBlocking class itself, so one hook is shared by every instance.
    not_available_hook = default_not_available_hook

    def __iter__(self):
        self.reset_iterator()
        return self

    def __next__(self):
        # Keep polling until nonblocking_next yields an item or signals exhaustion.
        while True:
            try:
                return self.nonblocking_next()
            except StopIteration:
                # Re-raise the original exception (keeps its value and traceback)
                # instead of replacing it with a fresh StopIteration.
                raise
            except NotAvailable:
                if NonBlocking.not_available_hook is not None:
                    NonBlocking.not_available_hook()

    def nonblocking_next(self):
        """Return the next item, or raise ``NotAvailable`` / ``StopIteration``."""
        raise NotImplementedError(
            "nonblocking_next is not implemented for %s" % self.__class__)

    def reset_iterator(self):
        """Restart iteration from the beginning."""
        raise NotImplementedError(
            "reset_iterator is not implemented for %s" % self.__class__)

    @staticmethod
    def register_not_available_hook(hook_function):
        """Replace the hook called between retries for all NonBlocking instances (``None`` disables it)."""
        NonBlocking.not_available_hook = hook_function


def EnsureNonBlockingDataPipe(validated_datapipe):
    """
    Make ``validated_datapipe`` usable through the NonBlocking interface.

    ``NonBlocking`` instances are returned unchanged. Any other ``IterDataPipe``
    is patched in place and the same object is returned. Only missing
    ``_as_iterator``, ``nonblocking_next`` and ``reset_iterator`` attributes are
    added; existing ones are kept.

    The added ``nonblocking_next`` lazily creates an iterator over the datapipe
    and advances it with ``next`` (so ``StopIteration`` propagates when it is
    exhausted). The added ``reset_iterator`` drops that iterator so the next
    call starts a fresh one.

    Raises:
        Exception: if ``validated_datapipe`` is not an ``IterDataPipe``.
    """
    if not isinstance(validated_datapipe, IterDataPipe):
        raise Exception('Not Iterable DataPipe ' + str(validated_datapipe.__class__))
    if isinstance(validated_datapipe, NonBlocking):
        return validated_datapipe
    if not hasattr(validated_datapipe, '_as_iterator'):
        validated_datapipe._as_iterator = None  # type: ignore[attr-defined]
    if not hasattr(validated_datapipe, 'nonblocking_next'):
        def nonblocking_next(self):
            if self._as_iterator is None:
                self._as_iterator = iter(self)
            return next(self._as_iterator)
        # Bound to this instance only; the datapipe's class is left untouched.
        validated_datapipe.nonblocking_next = types.MethodType(  # type: ignore[attr-defined]
            nonblocking_next, validated_datapipe)
    if not hasattr(validated_datapipe, 'reset_iterator'):
        def reset_iterator(self):
            self._as_iterator = None
        validated_datapipe.reset_iterator = types.MethodType(  # type: ignore[attr-defined]
            reset_iterator, validated_datapipe)
    return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
    """
    Serve requests received through ``protocol`` using items from ``source_datapipe``.

    This is a generator: nothing, including the argument checks, runs until it
    is first advanced. It yields ``True`` whenever it hands control back to the
    caller, so the caller can drive it cooperatively. That happens when:
    - no request is pending;
    - ``source_datapipe`` has no item ready;
    - a ``GetNextRequest`` has just been answered (unless ``full_stop`` ends
      the generator instead).

    Requests handled:
        ResetIteratorRequest: reset ``source_datapipe`` and acknowledge.
        TerminateRequest: acknowledge and finish the generator.
        GetNextRequest: reply with the next item, a stop-iteration response,
            or an invalid-state response.

    Args:
        source_datapipe: ``IterDataPipe`` to serve; wrapped with
            ``EnsureNonBlockingDataPipe``.
        protocol: must be an ``IterDataPipeQueueProtocolServer``.
        full_stop: if True, finish the generator right after replying to a
            ``StopIteration`` or ``InvalidStateResetRequired`` raised by
            ``source_datapipe``. Otherwise keep serving requests.
        blocking_request_get: passed as ``block`` to ``protocol.get_new_request``.

    Raises:
        Exception: if ``protocol`` has the wrong type, if ``source_datapipe``
            is not an ``IterDataPipe``, or if an unrecognized request arrives.
    """
    if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer):
        raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol)
    source_datapipe = EnsureNonBlockingDataPipe(source_datapipe)
    forever = True
    while forever:
        try:
            # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround
            request = protocol.get_new_request(block=blocking_request_get)
        except communication.protocol.EmptyQueue:
            yield True
            continue

        if isinstance(request, communication.messages.ResetIteratorRequest):
            source_datapipe.reset_iterator()
            protocol.response_reset_iterator()

        elif isinstance(request, communication.messages.TerminateRequest):
            forever = False
            protocol.response_terminate()

        elif isinstance(request, communication.messages.GetNextRequest):
            # Retry until the source produces an item or a terminal condition.
            while forever:
                try:
                    value = source_datapipe.nonblocking_next()
                except NotAvailable:
                    yield True
                    continue
                except StopIteration:
                    protocol.response_stop_iteration()
                    if full_stop:
                        forever = False
                    else:
                        yield True
                    break
                except InvalidStateResetRequired:
                    protocol.response_invalid_state()
                    if full_stop:
                        forever = False
                    else:
                        yield True
                    break
                protocol.response_next(value)
                yield True  # Returns control
                break
        else:
            raise Exception('Unrecognized type of request received', request)


class QueueWrapper(NonBlocking):
    """
    Creates an iter.DataPipe which reads data from the DataLoader.Queue.

    Client side of the request/response protocol. Each item is requested
    through ``protocol`` (an ``IterDataPipeQueueProtocolClient``) and then read
    from its response side. Presumably paired with a ``DataPipeBehindQueues``
    server on the other end -- confirm against callers.
    """

    def __init__(self, protocol, response_wait_time=0.00001):
        """
        Args:
            protocol: an ``IterDataPipeQueueProtocolClient``.
            response_wait_time: ``timeout`` passed to
                ``protocol.get_response_next`` on each ``nonblocking_next`` call.
                Presumably in seconds, as for queue timeouts -- TODO confirm.

        Raises:
            Exception: if ``protocol`` is not an ``IterDataPipeQueueProtocolClient``.
        """
        if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient):
            raise Exception('Got', protocol)
        self.protocol = protocol
        # NOTE(review): ``counter`` is only ever reset to 0 in this module;
        # confirm whether anything elsewhere reads or updates it.
        self.counter = 0
        # Set once a StopIterationResponse arrives; cleared by reset_iterator().
        self._stop_iteration = False
        self._response_wait_time = response_wait_time

    def reset_iterator(self):
        """Request a reset through ``protocol`` and poll until it is acknowledged."""
        self._stop_iteration = False
        self.counter = 0
        self.protocol.request_reset_iterator()
        while True:
            try:
                self.protocol.get_response_reset_iterator()
                break
            except communication.protocol.EmptyQueue:
                if NonBlocking.not_available_hook is not None:
                    NonBlocking.not_available_hook()

    def nonblocking_next(self):
        """
        Return the next item if a response arrives within ``response_wait_time``.

        Raises:
            NotAvailable: no response yet, or an invalid-state response was
                received. ``NonBlocking.__next__`` retries in that case.
            StopIteration: a stop-iteration response was received.
            Exception: if called again after StopIteration without a reset.
        """
        if self._stop_iteration:
            raise Exception(
                '`next` or `nonblocking_next` called after receiving StopIteration')
        # Only send a new request when the protocol accepts one. Otherwise,
        # presumably, a request is already outstanding and we just wait for it.
        if self.protocol.can_take_request():
            self.protocol.request_next()
        try:
            response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            raise NotAvailable
        if isinstance(response, communication.messages.StopIterationResponse):
            self._stop_iteration = True
            raise StopIteration
        if isinstance(response, communication.messages.InvalidStateResponse):
            raise NotAvailable
        return response.value