import torch import time import torch.distributed.rpc as rpc from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs from torch.testing._internal.dist_utils import ( dist_init, wait_until_pending_futures_and_users_flushed, wait_until_owners_and_forks_on_rank, worker_name, ) from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) def my_sleep_func(seconds=1): time.sleep(seconds) return torch.mul(torch.tensor(1), torch.tensor(1)) @torch.jit.script def my_script_func(tensor): return torch.add(tensor, tensor) def add_rref_to_value(rref, value): return rref.to_here() + value class FaultyAgentRpcTest(RpcAgentTestFixture): # no faulty_messages defined so this fails all retryable messages - see # faulty_rpc_agent_test_fixture.py for the list of retryable messages. @dist_init(messages_to_delay={}) def test_check_failed_messages(self): if self.rank == 0: dst_worker_b = worker_name((self.rank + 1) % self.world_size) dst_worker_c = worker_name((self.rank + 2) % self.world_size) # Worker0 sends RPC to Worker1 and creates an RRef there rref = rpc.remote(dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))) # Worker0 sends an RPC to Worker2 with the RRef as an arg rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2))) # check if the output is as expected self.assertEqual(rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))) # explicitly delete all User RRefs _delete_all_user_and_unforked_owner_rrefs() @dist_init def test_verify_backend_options(self): self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE) self.assertEqual(self.rpc_backend_options.num_worker_threads, 8) self.assertEqual(self.rpc_backend_options.num_fail_sends, 3) self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4) self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2) self.assertEqual(self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"]) def test_custom_faulty_messages(self): self.assertEqual( set(["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"]), set(self.rpc_backend_options.messages_to_fail), ) @dist_init(faulty_messages=[]) def test_no_faulty_messages(self): self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0) @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5}) def test_custom_messages_to_delay(self): self.assertEqual(self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}) def _test_remote_message_dropped_pickle(self, dst=None): if self.rank != 0: return dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size dst_worker = "worker{}".format(dst_rank) # Since we fail python_remote_call messages synchronously, the future # corresponding to this remote call will be marked with an error when # this function returns. rref = rpc.remote(dst_worker, my_sleep_func, args=(1,)) # Call to ensure pending callbacks are run. wait_until_pending_futures_and_users_flushed() # Attempt to fork the RRef should raise an error indicating the rpc.remote timeout. with self.assertRaisesRegex(RuntimeError, "RRef creation"): rref._serialize() # Test that using RRef as arg over RPC (which forks) results in the same # error with self.assertRaisesRegex(RuntimeError, "RRef creation"): rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1)) @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) def test_remote_message_dropped_pickle(self): self._test_remote_message_dropped_pickle() @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) def test_remote_message_dropped_pickle_to_self(self): self._test_remote_message_dropped_pickle(self.rank) def _test_remote_message_dropped_timeout(self, func, args, dst=None): if self.rank != 0: return # test the case where rpc.remote() message creation is completely dropped. dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size dst_worker = "worker{}".format(dst_rank) # Since we fail python_remote_call messages synchronously, the future # corresponding to this remote call will be marked with an error when # this function returns. rref = rpc.remote(dst_worker, func, args=args) # Call to ensure pending callbacks are run. wait_until_pending_futures_and_users_flushed() with self.assertRaisesRegex(RuntimeError, "RRef creation"): rref.to_here() # Note: during shutdown, logs will indicate "Could not find OwnerRRef..." # on the owning nodes, this is expected because the OwnerRRef was never # successfully created. Therefore, delAllUsers will work as expected. @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) def test_builtin_remote_message_dropped_timeout(self): func = torch.add args = (torch.tensor(1), torch.tensor(1)) self._test_remote_message_dropped_timeout(func, args) @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"]) def test_builtin_remote_message_dropped_timeout_to_self(self): func = torch.add args = (torch.tensor(1), torch.tensor(1)) self._test_remote_message_dropped_timeout(func, args, dst=0) @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) def test_udf_remote_message_dropped_timeout(self): func = my_sleep_func args = (2,) self._test_remote_message_dropped_timeout(func, args) @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"]) def test_udf_remote_message_dropped_timeout_to_self(self): func = my_sleep_func args = (2,) self._test_remote_message_dropped_timeout(func, args, dst=0) def _test_remote_message_delay_timeout(self, func, args, dst=None): if self.rank != 0: return # Test the case where remote message is eventually processed on the owner, # but the future on the creator times out before the response comes back. dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size dst_worker = "worker{}".format(dst_rank) # 10 ms timeout rref = rpc.remote(dst_worker, func, args=args, timeout=0.001) # Future corresponding to the remote creation should time out. expected_error = self.get_timeout_error_regex() with self.assertRaisesRegex(RuntimeError, expected_error): rref._get_future().wait() # Call to ensure pending callbacks are run. wait_until_pending_futures_and_users_flushed() # to_here() should now pick up that rpc.remote() creation has failed. with self.assertRaisesRegex(RuntimeError, "RRef creation"): rref.to_here() # Test the case where rpc.remote() times out, but to_here() has already # started blocking before. # NOTE: we only test this when not sending to self, as to_here() calls # calls localValue(), which does not send an RPC and thus does not have # a timeout. This can be supported by allowing future.wait() to # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280) if dst_rank != self.rank: slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2) with self.assertRaisesRegex(RuntimeError, expected_error): # to_here() should raise timeout error, since it does not know about the # status of rpc.remote(). slow_rref.to_here(0.001) # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete # but this can be a noop since it may not exist on the owner yet. Later, # the owner can process the RRef creation and wait for the delete message, # thus leading to a timeout. # Therefore, we wait until we get notification that pending owners have # been confirmed before sending out RRefUserDeletes. if dst_rank != self.rank: wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank) @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2}) def test_udf_remote_message_delay_timeout(self): func = my_sleep_func args = (2,) self._test_remote_message_delay_timeout(func, args) @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2}) def test_udf_remote_message_delay_timeout_to_self(self): func = my_sleep_func args = (1,) self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init( faulty_messages=[], messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, ) def test_remote_message_builtin_delay_timeout(self): func = torch.add args = (torch.tensor(1), torch.tensor(1)) self._test_remote_message_delay_timeout(func, args) @dist_init( faulty_messages=[], messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, ) def test_remote_message_builtin_delay_timeout_to_self(self): func = torch.add args = (torch.tensor(1), torch.tensor(1)) self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init( faulty_messages=[], messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, ) def test_remote_message_script_delay_timeout(self): func = my_script_func args = (torch.tensor(1),) self._test_remote_message_delay_timeout(func, args) @dist_init( faulty_messages=[], messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1}, ) def test_remote_message_script_delay_timeout_to_self(self): func = my_script_func args = (torch.tensor(1),) self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1}) def test_rref_to_here_timeout(self): if self.rank != 0: return dst_rank = (self.rank + 1) % self.world_size dst_worker = "worker{}".format(dst_rank) rref = rpc.remote( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) ) expected_error = self.get_timeout_error_regex() with self.assertRaisesRegex(RuntimeError, expected_error): rref.to_here(0.01) rref.to_here() @dist_init(faulty_messages=[]) def test_rpc_builtin_timeout(self): next_rank = (self.rank + 1) % self.world_size dst_worker = worker_name(next_rank) expected_error = self.get_timeout_error_regex() # PYTHON_CALL message types which correspond to Python UDF over RPC # by default get a delay (see faulty_rpc_agent_test_fixture) with self.assertRaisesRegex(RuntimeError, expected_error): rpc.rpc_sync( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1, ) fut = rpc.rpc_async( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1 ) with self.assertRaisesRegex(RuntimeError, expected_error): fut.wait() # Ensure that the currently set default timeout is large enough such # that RPCs with delays still complete. fut = rpc.rpc_async( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) ) fut.wait() # Ensure timeout if we set a new default and don't override rpc._set_rpc_timeout(0.001) fut = rpc.rpc_async( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) ) with self.assertRaisesRegex(RuntimeError, expected_error): fut.wait() # Ensure run to completion if we specify timeout of 0 fut = rpc.rpc_async( dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0 ) fut.wait() # Reset for clean shutdown rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5}) def test_rpc_script_timeout(self): next_rank = (self.rank + 1) % self.world_size dst_worker = worker_name(next_rank) expected_error = self.get_timeout_error_regex() with self.assertRaisesRegex(RuntimeError, expected_error): rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1) fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1) with self.assertRaisesRegex(RuntimeError, expected_error): fut.wait() # Ensure that the currently set default timeout is large enough such # that RPCs with delays still complete. fut = rpc.rpc_async( dst_worker, my_script_func, args=(torch.tensor(1),) ) fut.wait() # Ensure timeout if we set a new default and don't override rpc._set_rpc_timeout(0.001) fut = rpc.rpc_async( dst_worker, my_script_func, args=(torch.tensor(1),) ) with self.assertRaisesRegex(RuntimeError, expected_error): fut.wait() # Ensure run to completion if we specify timeout of 0 rpc._set_rpc_timeout(0.001) fut = rpc.rpc_async( dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0 ) fut.wait() # Reset for clean shutdown rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)