# If you need to modify this file to make this test pass, please also apply the
# same edits to
# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server

import threading
from datetime import datetime
from time import perf_counter

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch import optim

from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture

batch_size = 20
in_features = 100
out_features = 30
num_batches = 4


def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")


class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size):
        self.model = nn.Linear(in_features, out_features)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        # Pre-allocate zeroed .grad tensors so trainer gradients can be
        # accumulated in place in update_and_fetch_model.
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        # With @rpc.functions.async_execution, this returns a Future and the
        # RPC response is only sent once that Future is completed, i.e. once
        # all trainers in the current round have reported their gradients.
        self = ps_rref.local_value()
        for p, g in zip(self.model.parameters(), grads):
            p.grad += g
        with self.lock:
            timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
            self.curr_update_size += 1
            # Every trainer in this round waits on the same Future; it is
            # fulfilled with the updated model by the last trainer to report.
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.L1Loss()

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, in_features)
            labels = torch.zeros(batch_size, out_features)
            yield inputs, labels

    def train(self):
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            # Blocks until the PS has received gradients from all trainers,
            # applied the optimizer step, and fulfilled the shared Future.
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            )
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
    timed_log("Start training")
    start = perf_counter()
    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )
    torch.futures.wait_all(futs)
    stop = perf_counter()
    timed_log("Finish training")
    timed_log(f"Time spent training: {stop - start}s")


class ParameterServerTest(RpcAgentTestFixture):

    @dist_init(setup_rpc=False)
    def test_batch_updating_parameter_server(self):
        if self.rank != 0:
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        else:
            # Rank 0 hosts the parameter server and drives the trainers.
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
            run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)])

        rpc.shutdown()
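

# The block below is a minimal standalone launcher sketch, not exercised by the
# test above: it follows the multiprocess pattern of the pytorch/examples
# script linked in the header comment. The worker names, world size, and
# master address/port are illustrative assumptions, not values used by the
# test fixture.
import os

import torch.multiprocessing as mp


def _run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"  # assumed single-host run
    os.environ["MASTER_PORT"] = "29500"      # assumed free port
    if rank == 0:
        # Rank 0 hosts the parameter server and drives the trainers.
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        run_ps([f"trainer{r}" for r in range(1, world_size)])
    else:
        # Trainers only initialize RPC; run_trainer is invoked on them by the PS.
        rpc.init_rpc(f"trainer{rank}", rank=rank, world_size=world_size)
    # Blocks until all outstanding RPC work across all workers is done.
    rpc.shutdown()


if __name__ == "__main__":
    # One parameter server plus two trainers, all on this host.
    mp.spawn(_run_worker, args=(3,), nprocs=3, join=True)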