import torch
import threading
import pickle

from torch.utils.data import IterDataPipe, communication, MapDataPipe

try:
    import dill
    # XXX: By default, dill overwrites the Pickler dispatch table to inject its
    # own logic there. This globally affects the behavior of the standard library
    # pickler for any user who transitively depends on this module!
    # Undo this extension to avoid altering the behavior of the pickler globally.
    dill.extend(use_dill=False)
    HAS_DILL = True
except ImportError:
    HAS_DILL = False

__all__ = [
    "DataPipeToQueuesLoop",
    "SpawnProcessForDataPipeline",
    "SpawnThreadForDataPipeline",
]


def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
    """Drive ``source_datapipe`` behind a request/response queue pair.

    The communication module and the protocol server class are chosen from
    the type of ``source_datapipe``; the loop then exhausts the
    ``DataPipeBehindQueues`` generator, discarding whatever it yields.

    Args:
        source_datapipe: An ``IterDataPipe`` or ``MapDataPipe`` instance.
        req_queue: Queue handed to the protocol server as its request queue.
        res_queue: Queue handed to the protocol server as its response queue.

    Raises:
        Exception: If ``source_datapipe`` is neither an ``IterDataPipe`` nor
            a ``MapDataPipe``.
    """
    if isinstance(source_datapipe, IterDataPipe):
        pipe_type = communication.iter
        protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
    elif isinstance(source_datapipe, MapDataPipe):
        pipe_type = communication.map  # type: ignore[misc]
        protocol_type = communication.protocol.MapDataPipeQueueProtocolServer  # type: ignore[assignment]
    else:
        raise Exception('Only supports IterDataPipe or MapDataPipe, got', source_datapipe)

    # Restrict torch to a single intra-op thread inside this loop
    # (presumably to avoid CPU oversubscription across workers -- TODO confirm).
    torch.set_num_threads(1)

    protocol = protocol_type(req_queue, res_queue)
    # The generator is consumed only for its side effects; the yielded values
    # are intentionally ignored.
    for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=True):
        pass


def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
    """Create (but do not start) a process running ``DataPipeToQueuesLoop``.

    Args:
        multiprocessing_ctx: Object providing ``Queue`` and ``Process``
            factories (e.g. a ``multiprocessing`` context).
        datapipe: DataPipe passed to ``DataPipeToQueuesLoop`` in the process.

    Returns:
        A ``(process, req_queue, res_queue)`` tuple. The caller is responsible
        for starting the process.
    """
    req_queue = multiprocessing_ctx.Queue()
    res_queue = multiprocessing_ctx.Queue()
    process = multiprocessing_ctx.Process(
        target=DataPipeToQueuesLoop,
        args=(datapipe, req_queue, res_queue),
    )
    return process, req_queue, res_queue


def SpawnThreadForDataPipeline(datapipe):
    """Create (but do not start) a daemon thread running ``DataPipeToQueuesLoop``.

    A copy of ``datapipe`` is made by a ``pickle`` round trip so that the
    thread works on its own instance. If pickling fails and ``dill`` is
    installed, a ``dill`` round trip is attempted instead.

    Args:
        datapipe: DataPipe to copy and serve from the new thread.

    Returns:
        A ``(thread, req_queue, res_queue, thread_local_datapipe)`` tuple. The
        caller is responsible for starting the thread.

    Raises:
        Exception: If the DataPipe can be copied with neither ``pickle`` nor
            (when available) ``dill``. The original error is attached as the
            second argument and as ``__cause__``.
    """
    req_queue = communication.queue.ThreadingQueue()
    res_queue = communication.queue.ThreadingQueue()

    try:
        new_datapipe = pickle.loads(pickle.dumps(datapipe))
    except Exception as pe:
        if HAS_DILL:
            # Fall back to dill, which can serialize more objects than pickle.
            try:
                new_datapipe = dill.loads(dill.dumps(datapipe))
            except Exception as de:
                raise Exception('Unable to dill DataPipe to make thread local copy', de) from de
        else:
            raise Exception(
                'Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe
            ) from pe

    thread = threading.Thread(
        target=DataPipeToQueuesLoop,
        args=(new_datapipe, req_queue, res_queue),
        daemon=True,
    )
    return thread, req_queue, res_queue, new_datapipe