#!/usr/bin/python3 import enum from typing import Tuple import torch import torch.distributed.rpc as rpc import torch.testing._internal.dist_utils as dist_utils from torch import Tensor, nn from torch._jit_internal import Future from torch.distributed.nn import RemoteModule from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES from torch.distributed.nn.api.remote_module import _RemoteModule from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) _PARAM_VAL = torch.nn.Parameter(torch.ones(1)) # RPC handler for querying the device on the destination worker. def remote_device(module_rref): for param in module_rref.local_value().parameters(): return param.device # RPC handler for querying __dict__ on the destination worker. def remote_module_attributes(remote_module): return remote_module.__dict__ # RPC handler for running forward on the destination worker. def remote_forward(remote_module, args): return remote_module.forward(*args) # RPC handler for running forward_async on the destination worker. def remote_forward_async(remote_module, args): # Since future cannot be pickled and sent over the RPC layer, # have to wait and behave just like ``forward_sync``. return remote_module.forward_async(*args).wait() # RPC handler for getting training mode on the destination worker. def get_remote_training_arg(module_rref): return module_rref.local_value().training class ModuleCreationMode(enum.Enum): MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface" MODULE_CTOR = "module_ctor" @torch.jit.interface class MyModuleInterface: def forward( self, tensor: Tensor, number: int, word: str = "default" ) -> Tuple[str, int, Tensor]: # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well pass @torch.jit.interface class RemoteMyModuleInterface: def forward( self, tensor: Tensor, number: int, word: str = "default" ) -> Tuple[str, int, Tensor]: # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well pass def forward_async( self, tensor: Tensor, number: int, word: str = "default" ) -> Future[Tuple[str, int, Tensor]]: pass class MyModule(nn.Module): def __init__(self, first_arg, first_kwarg=-1): super().__init__() self.param1 = _PARAM_VAL def forward( self, tensor: Tensor, number: int, word: str = "default" ) -> Tuple[str, int, Tensor]: return word, number, tensor class BadModule: def __init__(self, first_arg, first_kwarg=-1): pass def create_scripted_module(first_arg, first_kwarg=-1): module = MyModule(first_arg, first_kwarg=first_kwarg) scripted_module = torch.jit.script(module) return scripted_module # Common utils for both CPU and CUDA test suites class CommonRemoteModuleTest(RpcAgentTestFixture): @property def world_size(self): # Override setting in RpcAgentTestFixture return 2 @staticmethod def _create_remote_module_iter(remote_device, modes=None): if modes is None: modes = ModuleCreationMode.__members__.values() args = (1,) kwargs = dict(first_kwarg=2) if ModuleCreationMode.MODULE_CTOR in modes: remote_module = RemoteModule(remote_device, MyModule, args, kwargs) yield remote_module if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: remote_module = _RemoteModule( remote_device, create_scripted_module, args, kwargs, _module_interface_cls=MyModuleInterface, ) scripted_remote_module = torch.jit.script(remote_module) yield scripted_remote_module class RemoteModuleTest(CommonRemoteModuleTest): @dist_utils.dist_init def test_bad_module(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) remote_device = "{}/cpu".format(dst_worker_name) args = (1,) kwargs = dict(first_kwarg=2) with self.assertRaisesRegex( ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): RemoteModule(remote_device, BadModule, args, kwargs).forward() with self.assertRaisesRegex( ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): RemoteModule(remote_device, BadModule, args, kwargs).forward() @dist_utils.dist_init def test_forward_async(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) args = (torch.ones(1), 2, "3") for remote_module in self._create_remote_module_iter(dst_worker_name): ret_fut = remote_module.forward_async(*args) ret = ret_fut.wait() self.assertEqual(ret, tuple(reversed(args))) @dist_utils.dist_init def test_forward_async_script(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) scripted_remote_module = next( self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ) ) @torch.jit.script def run_forward_async(scripted_remote_module: RemoteMyModuleInterface): ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3") ret = ret_fut.wait() return ret ret = run_forward_async(scripted_remote_module) self.assertEqual(ret, ("3", 2, torch.ones(1))) @dist_utils.dist_init def test_forward_sync(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) args = (torch.ones(1), 2, "3") for remote_module in self._create_remote_module_iter(dst_worker_name): ret = remote_module.forward(*args) self.assertEqual(ret, tuple(reversed(args))) @dist_utils.dist_init def test_forward_sync_script(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) scripted_remote_module = next( self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ) ) @torch.jit.script def run_forward(scripted_remote_module: MyModuleInterface): ret = scripted_remote_module.forward(torch.ones(1), 2, "3") return ret ret = run_forward(scripted_remote_module) self.assertEqual(ret, ("3", 2, torch.ones(1))) @dist_utils.dist_init def test_forward_with_kwargs(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) args = (torch.ones(1), 2) kwargs = dict(word="3") # Only test Python nn.Module, because script module methods don't support taking kwargs. for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): ret_fut = remote_module.forward_async(*args, **kwargs) ret = ret_fut.wait() self.assertEqual(ret, tuple(reversed(args + ("3",)))) ret = remote_module.forward(*args, **kwargs) self.assertEqual(ret, tuple(reversed(args + ("3",)))) @dist_utils.dist_init def test_remote_parameters(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) # Only test Python nn.Module, because script module methods don't support ``remote_parameters``. for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): param_rrefs = remote_module.remote_parameters() self.assertEqual(len(param_rrefs), 1) self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL)) @dist_utils.dist_init def test_get_module_rref(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) # Only test Python nn.Module, because script module methods don't support ``get_module_rref``. for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): rref = remote_module.get_module_rref() self.assertEqual(rref, remote_module.module_rref) for param in rref.to_here().parameters(): self.assertTrue(torch.equal(param, _PARAM_VAL)) @dist_utils.dist_init def test_train_eval(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): remote_module.train() ret1 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),)) self.assertEqual(ret1, True) remote_module.eval() ret2 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),)) self.assertEqual(ret2, False) @dist_utils.dist_init def test_unsupported_methods(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): with self.assertRaisesRegex( ValueError, r"Method ``register_buffer`` not supported for RemoteModule" ): remote_module.register_buffer("buffer", torch.ones(5)) with self.assertRaisesRegex( ValueError, r"Method ``register_parameter`` not supported for RemoteModule", ): remote_module.register_parameter( "param", torch.nn.Parameter(torch.ones(1)) ) with self.assertRaisesRegex( ValueError, r"Method ``add_module`` not supported for RemoteModule" ): remote_module.add_module("empty", None) with self.assertRaisesRegex( ValueError, r"Method ``apply`` not supported for RemoteModule" ): fn = torch.rand((3, 3), requires_grad=False) remote_module.apply(fn) with self.assertRaisesRegex( ValueError, r"Method ``cuda`` not supported for RemoteModule" ): remote_module.cuda() with self.assertRaisesRegex( ValueError, r"Method ``cpu`` not supported for RemoteModule" ): remote_module.cpu() with self.assertRaisesRegex( ValueError, r"Method ``type`` not supported for RemoteModule" ): remote_module.type(torch.FloatTensor) with self.assertRaisesRegex( ValueError, r"Method ``float`` not supported for RemoteModule" ): remote_module.float() with self.assertRaisesRegex( ValueError, r"Method ``double`` not supported for RemoteModule" ): remote_module.double() with self.assertRaisesRegex( ValueError, r"Method ``bfloat16`` not supported for RemoteModule" ): remote_module.bfloat16() with self.assertRaisesRegex( ValueError, r"Method ``to`` not supported for RemoteModule" ): remote_module.to("cpu", dtype=torch.int32) def hook(module, grad_input, grad_output): pass with self.assertRaisesRegex( ValueError, r"Method ``register_backward_hook`` not supported for RemoteModule", ): remote_module.register_backward_hook(hook) with self.assertRaisesRegex( ValueError, r"Method ``register_forward_pre_hook`` not supported for RemoteModule", ): remote_module.register_forward_pre_hook(hook) with self.assertRaisesRegex( ValueError, r"Method ``register_forward_hook`` not supported for RemoteModule", ): remote_module.register_forward_hook(hook) with self.assertRaisesRegex( ValueError, r"Method ``state_dict`` not supported for RemoteModule" ): remote_module.state_dict() with self.assertRaisesRegex( ValueError, r"Method ``load_state_dict`` not supported for RemoteModule" ): remote_module.load_state_dict({}) with self.assertRaisesRegex( ValueError, r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.", ): remote_module.parameters() with self.assertRaisesRegex( ValueError, r"Method ``named_parameters`` not supported for RemoteModule", ): remote_module.named_parameters() with self.assertRaisesRegex( ValueError, r"Method ``buffers`` not supported for RemoteModule" ): remote_module.buffers() with self.assertRaisesRegex( ValueError, r"Method ``named_buffers`` not supported for RemoteModule" ): remote_module.named_buffers() with self.assertRaisesRegex( ValueError, r"Method ``children`` not supported for RemoteModule" ): remote_module.children() with self.assertRaisesRegex( ValueError, r"Method ``named_children`` not supported for RemoteModule" ): remote_module.named_children() with self.assertRaisesRegex( ValueError, r"Method ``modules`` not supported for RemoteModule" ): remote_module.modules() with self.assertRaisesRegex( ValueError, r"Method ``named_modules`` not supported for RemoteModule" ): remote_module.named_modules() with self.assertRaisesRegex( ValueError, r"Method ``requires_grad_`` not supported for RemoteModule" ): remote_module.requires_grad_() with self.assertRaisesRegex( ValueError, r"Method ``zero_grad`` not supported for RemoteModule" ): remote_module.zero_grad() with self.assertRaisesRegex( ValueError, r"Method ``share_memory`` not supported for RemoteModule" ): remote_module.share_memory() with self.assertRaisesRegex( ValueError, r"Method ``extra_repr`` not supported for RemoteModule" ): remote_module.extra_repr() @dist_utils.dist_init def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) # If a new attribute is added to this RemoteModule after the initialization, # and it will be sent over the wire by RPC, # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES. # Note that adding a new attribute out of constructor should rarely happen. # If a new attribute is added to RemoteModule constructor, # there is a sanity check to enforce developers to add this attribute to either # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): new_attr_name = "new_attr" setattr(remote_module, new_attr_name, 1) attrs = rpc.rpc_sync( dst_worker_name, remote_module_attributes, (remote_module,) ) self.assertNotIn(new_attr_name, attrs) @dist_utils.dist_init def test_remote_module_py_pickle_not_supported(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] ): with TemporaryFileName() as fname: with self.assertRaisesRegex( RuntimeError, "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC", ): torch.save(remote_module, fname) @dist_utils.dist_init def test_remote_module_py_pickle_not_supported_script(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ): with TemporaryFileName() as fname: with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"): torch.save(remote_module, fname) class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest): @property def world_size(self): # Override setting in CommonRemoteModuleTest return 3 @dist_utils.dist_init def test_send_remote_module_over_the_wire(self): if self.rank != 0: return dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) # Unpickled attribtes include both the inherent attributes of RemoteModule # (not inherited from the superclass) and two installed methods. expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) expected_unpickled_attrs.append("forward_async") expected_unpickled_attrs.append("forward") # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. for remote_module in self._create_remote_module_iter( dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] ): # Test querying some simple attributes from worker2. attrs = rpc.rpc_sync( dst_worker2_name, remote_module_attributes, (remote_module,) ) self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs) self.assertEqual(attrs["on"], "worker1") self.assertEqual(attrs["device"], "cpu") self.assertFalse(attrs["is_device_map_set"]) self.assertFalse(attrs["is_scriptable"]) # Test the installed methods on worker1's can be initiated by worker2 over RPC layer. # NOTE: In practice a remote module should be directly stored on the worker that runs ``forward``` or ``forward_async``, # not have another worker to initiate forward over the RPC layer. args = (torch.ones(1), 2, "3") ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args)) self.assertEqual(ret1, tuple(reversed(args))) ret2 = rpc.rpc_sync( dst_worker2_name, remote_forward_async, (remote_module, args) ) self.assertEqual(ret2, tuple(reversed(args))) @dist_utils.dist_init def test_send_remote_module_over_the_wire_script_not_supported(self): if self.rank != 0: return dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) # Unpickled attribtes include both the inherent attributes of RemoteModule # (not inherited from the superclass) and two installed methods. expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) expected_unpickled_attrs.append("forward_async") expected_unpickled_attrs.append("forward") with self.assertRaisesRegex( RuntimeError, "Passing a script RemoteModule over RPC is not supported." ): # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. for remote_module in self._create_remote_module_iter( dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ): # Test querying some simple attributes from worker2. attrs = rpc.rpc_sync( dst_worker2_name, remote_module_attributes, (remote_module,) ) @dist_utils.dist_init def test_create_remote_module_from_module_rref(self): if self.rank != 0: return dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer. for remote_module in self._create_remote_module_iter( dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] ): remote_module2 = rpc.rpc_sync( dst_worker2_name, RemoteModule.init_from_module_rref, (dst_worker2_name, remote_module.get_module_rref()), ) args = (torch.ones(1), 2, "3") ret1 = rpc.rpc_sync( dst_worker1_name, remote_forward, (remote_module, args) ) ret2 = rpc.rpc_sync( dst_worker2_name, remote_forward, (remote_module2, args) ) self.assertEqual(ret2, ret2) class CudaRemoteModuleTest(CommonRemoteModuleTest): @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_valid_device(self): if self.rank != 0: return dst_rank = (self.rank + 1) % self.world_size dst_worker_name = dist_utils.worker_name(dst_rank) for remote_module in self._create_remote_module_iter( "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] ): device = rpc.rpc_sync( dst_worker_name, remote_device, (remote_module.module_rref,) ) self.assertEqual(device.type, "cuda") self.assertEqual(device.index, 0) # Test rank works as well. for remote_module in self._create_remote_module_iter( "rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR] ): device = rpc.rpc_sync( dst_worker_name, remote_device, (remote_module.module_rref,) ) self.assertEqual(device.type, "cuda") self.assertEqual(device.index, 0) @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_invalid_devices(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) with self.assertRaisesRegex( RuntimeError, r"Expected one of .+ device type at start of device string", ): list( m.forward() for m in self._create_remote_module_iter( "{}/foo".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex( RuntimeError, r"CUDA error: invalid device ordinal" ): list( m.forward() for m in self._create_remote_module_iter( "{}/cuda:100".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): list( m.forward() for m in self._create_remote_module_iter( "{}/cpu2".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): list( m.forward() for m in self._create_remote_module_iter( "{}/".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", ): list( m.forward() for m in self._create_remote_module_iter( "{}/cuda:0/cuda:1".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: /. The valid format is '/'", ): list( m.forward() for m in self._create_remote_module_iter( "/", modes=[ModuleCreationMode.MODULE_CTOR], ) ) with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: /cuda:0. The valid format is '/'", ): list( m.forward() for m in self._create_remote_module_iter( "/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR], ) ) @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_input_moved_to_cuda_device(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device. t1 = torch.ones(1) args = (t1, 2) t2 = t1 * 2 kwargs = dict(word=t2) # Only test Python nn.Module, because script module methods don't support taking kwargs. for remote_module in self._create_remote_module_iter( "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] ): ret_fut = remote_module.forward_async(*args, **kwargs) ret = ret_fut.wait() self.assertEqual(ret, tuple(reversed(args + (t2,)))) # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". self.assertEqual(ret[0].device.type, "cpu") self.assertEqual(ret[2].device.type, "cpu") ret = remote_module.forward(*args, **kwargs) self.assertEqual(ret, tuple(reversed(args + (t2,)))) # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". self.assertEqual(ret[0].device.type, "cpu") self.assertEqual(ret[2].device.type, "cpu") @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_input_moved_to_cuda_device_script(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) scripted_remote_module = next( self._create_remote_module_iter( "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE], ) ) @torch.jit.script def run_forward(scripted_remote_module: MyModuleInterface): ret = scripted_remote_module.forward(torch.ones(1), 2, "3") return ret ret = run_forward(scripted_remote_module) self.assertEqual(ret, ("3", 2, torch.ones(1))) # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". self.assertEqual(ret[2].device.type, "cpu")