import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import (
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm,
)
from torch.distributed.pipeline.sync import Pipe


class PipeWithDDPTest(RpcAgentTestFixture):
    """
    Verifies that a ``Pipe`` model composes correctly with
    ``DistributedDataParallel``: each rank builds a two-stage pipeline across a
    pair of GPUs, the pipeline replicas are wrapped in DDP, and gradients are
    checked to be synchronized across ranks for each checkpointing mode.
    """

    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never(self):
        self._run_basic_test("nccl", "never")

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never_find_unused(self):
        self._run_basic_test("nccl", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_always(self):
        self._run_basic_test("nccl", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_except_last(self):
        self._run_basic_test("nccl", "except_last", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never(self):
        self._run_basic_test("gloo", "never")

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never_find_unused(self):
        self._run_basic_test("gloo", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_always(self):
        self._run_basic_test("gloo", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_except_last(self):
        self._run_basic_test("gloo", "except_last", static_graph=True)

    def _run_basic_test(
        self, backend, checkpoint, find_unused_parameters=False, static_graph=False
    ):
        dist.init_process_group(
            backend=backend,
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another
        # pipe between GPU 2 and 3. Both replicas are replicated via DDP.
        fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank)

        class MyModule(nn.Module):
            def __init__(self, device):
                super(MyModule, self).__init__()
                self.fc2 = nn.Linear(8, 4, bias=False).cuda(device)
                self.fc3 = nn.Linear(4, 2, bias=False).cuda(device)

            def forward(self, inp):
                if find_unused_parameters:
                    return self.fc2(inp)
                else:
                    return self.fc3(self.fc2(inp))

        layer2 = MyModule(2 * self.rank + 1)
        model = nn.Sequential(
            fc1,
            layer2
        )
        model = Pipe(model, chunks=2, checkpoint=checkpoint)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=find_unused_parameters,
            static_graph=static_graph,
        )

        # Ensure inputs are different across ranks to verify that gradient
        # sync indeed occurs.
        model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
        out = model(model_input).local_value()
        out.sum().backward()

        # Run forward again for find_unused_parameters to trigger any potential errors.
        if find_unused_parameters:
            # Ensure inputs are different across ranks to verify that gradient
            # sync indeed occurs.
            unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
            model(unused_param_input).local_value().sum().backward()

        # Run a few more iterations of fwd + bwd to ensure gradient synchronization
        # occurs properly across iterations via delay_all_reduce/bucketized allreduce.
        for _ in range(3):
            model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
            out = model(model_input).local_value()
            out.sum().backward()

        # Check grads
        output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)]
        dist.all_gather(output, fc1.weight.grad)
        self.assertEqual(output[0], output[1])

        output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)]
        dist.all_gather(output, layer2.fc2.weight.grad)
        self.assertEqual(output[0], output[1])

        if not find_unused_parameters:
            output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)]
            dist.all_gather(output, layer2.fc3.weight.grad)
            self.assertEqual(output[0], output[1])
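
# NOTE: ``PipeWithDDPTest`` is a test fixture/mixin rather than a standalone
# ``TestCase``; the methods above only run once the class is combined with a
# concrete, process-spawning test class for a particular RPC agent. A minimal
# sketch of how such a concrete test module might look is given below purely
# for illustration -- the class and module names are assumptions and may not
# match the actual test layout in the repository:
#
#     from torch.testing._internal.common_distributed import MultiProcessTestCase
#     from torch.testing._internal.common_utils import run_tests
#
#     class HypotheticalPipeWithDDPTest(PipeWithDDPTest, MultiProcessTestCase):
#         pass
#
#     if __name__ == "__main__":
#         run_tests()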