import time import types from torch.utils.data import communication, MapDataPipe DEFAULT_NON_BLOCKING_SLEEP = 0.001 __all__ = [ "DataPipeBehindQueues", "EnsureNonBlockingMapDataPipe", "NonBlockingMap", "NotAvailable", "QueueWrapperForMap", "default_not_available_hook", ] def default_not_available_hook(): time.sleep(DEFAULT_NON_BLOCKING_SLEEP) class NotAvailable(Exception): pass class NonBlockingMap(MapDataPipe): not_available_hook = default_not_available_hook def __getitem__(self, index): while True: try: return self.nonblocking_getitem(index) except NotAvailable: if NonBlockingMap.not_available_hook is not None: NonBlockingMap.not_available_hook() def __len__(self): try: return self.nonblocking_len() except NotAvailable: if NonBlockingMap.not_available_hook is not None: NonBlockingMap.not_available_hook() def nonblocking_len(self): raise NotImplementedError( "nonblocking_len is not implemented for %s" % self.__class__) def nonblocking_getitem(self, index): raise NotImplementedError( "nonblocking_getitem is not implemented for %s" % self.__class__) @staticmethod def register_not_available_hook(hook_function): NonBlockingMap.not_available_hook = hook_function def EnsureNonBlockingMapDataPipe(validated_datapipe): if not isinstance(validated_datapipe, MapDataPipe): raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}') if isinstance(validated_datapipe, NonBlockingMap): return validated_datapipe if not hasattr(validated_datapipe, 'nonblocking_len'): def nonblocking_len(self): return self.__len__() validated_datapipe.nonblocking_len = types.MethodType( # type: ignore[attr-defined] nonblocking_len, validated_datapipe) if not hasattr(validated_datapipe, 'nonblocking_getitem'): def nonblocking_getitem(self, index): return self.__getitem__(index) validated_datapipe.nonblocking_getitem = types.MethodType( # type: ignore[attr-defined] nonblocking_getitem, validated_datapipe) return validated_datapipe def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): """ Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue If raise_stop is true, raises exception when StopIteration received from the source_datapipe """ if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): raise Exception('Expecting MapDataPipeQueueProtocolServer, got', protocol) source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe) forever = True while forever: try: # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround request = protocol.get_new_request(block=blocking_request_get) except communication.protocol.EmptyQueue: yield True continue if isinstance(request, communication.messages.TerminateRequest): forever = False protocol.response_terminate() elif isinstance(request, communication.messages.LenRequest): size = source_datapipe.nonblocking_len() protocol.response_len(size) elif isinstance(request, communication.messages.GetItemRequest): while forever: try: value = source_datapipe.nonblocking_getitem(request.key) except NotAvailable: yield True continue except IndexError as e: # Alternatively, we can just allow the underlying DataPipe to throw an exception? protocol.response_index_out_of_bound() if full_stop: forever = False else: yield True break protocol.response_item(request.key, value) yield True # Returns control break else: raise Exception('Unrecognized type of request received', request) class QueueWrapperForMap(NonBlockingMap): """ Creates map.DataPipe which reads data from the DataLoader.Queue """ def __init__(self, protocol, response_wait_time=0.00001): if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient): raise Exception('Got', protocol) self.protocol = protocol self.counter = 0 self._stop_iteration = False self._response_wait_time = response_wait_time def nonblocking_getitem(self, index): if self._stop_iteration: raise Exception( '`getitem` or `nonblocking_getitem` called after receiving StopIteration') if self.protocol.can_take_request(): self.protocol.request_item(index) try: response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time) except communication.protocol.EmptyQueue: raise NotAvailable if isinstance(response, communication.messages.StopIterationResponse): self._stop_iteration = True raise IndexError(f"Index {index} is out of bound.") return response.key, response.value def nonblocking_len(self): if self._stop_iteration: raise Exception( '`len` or `nonblocking_len` called after receiving StopIteration') if self.protocol.can_take_request(): self.protocol.request_len() try: response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time) except communication.protocol.EmptyQueue: raise NotAvailable return response.len