from typing import Dict, Tuple

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import rpc_async
from torch.testing import FileCheck
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def local_add(t1, t2):
    return torch.add(t1, t2)


# Issues an RPC from within TorchScript and blocks on the returned future.
@torch.jit.script
def remote_add(t1, t2, dst: str):  # noqa: E999
    return rpc_async(dst, local_add, (t1, t2)).wait()


# Runs remote_add asynchronously on a JIT fork thread, then waits on the fork.
@torch.jit.script
def fork_add(t1, t2, dst: str):
    fut = torch.jit._fork(remote_add, t1, t2, dst)
    return torch.jit._wait(fut)


class JitDistAutogradTest(RpcAgentTestFixture):
    @dist_init
    def test_get_gradients(self):
        # dist_autograd.get_gradients must be scriptable.
        @torch.jit.script
        def dist_get_gradients(context_id: int) -> Dict[Tensor, Tensor]:
            return dist_autograd.get_gradients(context_id)

        # Verify the op actually appears in the compiled graph.
        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))

        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_get_gradients(context_id)

            # d(sum(t1 + t2))/dt1 == d(sum(t1 + t2))/dt2 == ones.
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])

    @dist_init
    def test_dist_backward(self):
        if self.rank != 0:
            return

        # dist_autograd.backward must be callable from TorchScript.
        @torch.jit.script
        def dist_backward_script(context_id: int, loss: torch.Tensor):
            dist_autograd.backward(context_id, [loss])

        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))

        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
            dist_backward_script(context_id, loss)

    @dist_init
    def test_jit_fork_within_context(self):
        # An RPC launched from a jit._fork thread inside a dist_autograd
        # context must still record gradients for both inputs.
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            res = fork_add(t1, t2, dst_worker_name)
            loss = res.sum()
            dist_autograd.backward(context_id, [loss])

            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)

    @dist_init
    def test_restore_context_after_switch_to_jit_thread(self):
        if self.rank != 0:
            return

        @torch.jit.script
        def forward_script(
            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
        ) -> Tuple[Tensor, Tensor]:
            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
            res1 = res1_fut.wait()  # After this, the script runs on a new JIT thread.
            loss1 = res1.sum()

            # Unless the thread-local DistAutogradContext is restored after the
            # thread switch above, SendRpcBackward is not attached to this
            # second RPC.
            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
            res2 = res2_fut.wait()
            loss2 = res2.sum()

            return loss1, loss2

        with dist_autograd.context() as context_id:
            t1 = torch.ones((2, 3), requires_grad=True)
            t2 = torch.ones((2, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
            dist_autograd.backward(context_id, [loss0, loss1])
            # Both RPCs must have contributed gradients to the context.
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(grads[t1], grads[t2])
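
# ---------------------------------------------------------------------------
# Usage sketch (not part of the test suite): a minimal two-process example of
# the distributed autograd pattern the tests above exercise. The worker names,
# the rendezvous address/port, and the spawn-based launcher are illustrative
# assumptions, not anything this file or its fixtures define.
def _demo_worker(rank: int, world_size: int) -> None:
    import os

    # Assumed single-machine rendezvous settings for the default "env://"
    # init method.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            # The remote op records Send/RecvRpcBackward functions in the
            # distributed autograd graph for this context.
            loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
            dist_autograd.backward(context_id, [loss])
            print(dist_autograd.get_gradients(context_id))  # {t1: ones, t2: ones}
    rpc.shutdown()  # blocks until all workers finish


if __name__ == "__main__":
    import torch.multiprocessing as mp

    mp.spawn(_demo_worker, args=(2,), nprocs=2)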