from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
    wait_until_pending_futures_and_users_flushed,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    # Remote target for the rpc_async helpers below: returns the sum of all
    # four tensor arguments. The default tensors are created once when the
    # function is defined; they are only read here, never mutated.
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    # Issue rpc_async without an explicit timeout (so the agent's current
    # default applies) and block on the result inside TorchScript.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    # Issue rpc_async with an explicit per-call timeout and block on the
    # result inside TorchScript.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    # Issue rpc_async with an explicit per-call timeout and return the future
    # unwaited, so the caller can wait on it from Python.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut


@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    # Issue rpc_async without an explicit timeout and return the future
    # unwaited, so the caller can wait on it from Python.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    # Fetch the RRef's value from TorchScript using the default timeout.
    return rref_var.to_here()


@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
    # Fetch the RRef's value from TorchScript with an explicit timeout.
    return rref_var.to_here(timeout)


@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor:
    # Send an RRef as an RPC argument from TorchScript (exercising JIT
    # pickling of the RRef) and run rref_to_here on the destination.
    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
    ret = fut.wait()
    return ret


class JitFaultyAgentRpcTest(RpcAgentTestFixture):
    """
    Run tests for rpc_async in JIT under the faulty agent test fixture, which
    lets tests drop or delay specific message types to exercise arbitrary
    timeouts.
    """

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_torchscript_function(self):
        # Call rpc_async + fut.wait() in torchscript function and ensure that
        # timeout is raised.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()
        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer to execute.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

        # Ensure that we timeout if we don't specify a timeout but the default
        # is less than the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        try:
            with self.assertRaisesRegex(RuntimeError, expected_error):
                script_rpc_async_call(dst_worker_name, args, kwargs)

            # Ensure that we run to completion if zero timeout is specified.
            ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
            # 1 + 2 + 2 + 3 == 8 elementwise.
            self.assertEqual(ret, torch.tensor([8, 8]))
        finally:
            # Reset for clean shutdown, even if an assertion above failed;
            # otherwise the 0.001 default timeout set above stays in effect.
            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_python(self):
        # Ensures timeouts are raised if we call rpc_async from within a
        # torchscript function, but wait on the future in python.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()

        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure timeout if we don't specify but the default is less than the
        # RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        try:
            fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
            with self.assertRaisesRegex(RuntimeError, expected_error):
                fut.wait()

            # Ensure run to completion if zero timeout is specified
            fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
            result = fut.wait()
            self.assertEqual(result, torch.tensor([8, 8]))
        finally:
            # Reset for clean shutdown, even if an assertion above failed.
            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_remote_timeout_to_here_in_jit(self):
        # Test that calling to_here() in JIT will raise timeout error if
        # rpc.remote failed.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call to_here() within a ScriptFunction and ensure it raises
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref_to_here(rref)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
    def test_rref_to_here_timeout_in_jit(self):
        # A short to_here() timeout must fail while the fetch is delayed; a
        # generous one must succeed and yield the remote result.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref_to_here_with_timeout(rref, 0.01)

        # With a long timeout the delayed fetch completes; check the value
        # rather than only that no exception was raised.
        ret = rref_to_here_with_timeout(rref, 100)
        self.assertEqual(ret, torch.tensor(2))

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_in_jit(self):
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with RRef arg in JIT, which will go through JIT pickling and
        # ensure error is raised.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc_async_with_rref_arg(dst_worker, (rref, ))

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_script_func(self):
        # Similar to above test, but calls python rpc with script function.
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with script function that takes RRef, ensure timeout during pickling
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, ))