import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor


class ReplicatedTensorFunction(torch.autograd.Function):
    """
    Autograd function to ensure gradients are replicated between the
    replicated tensor and the original one.
    """
    @staticmethod
    def forward(ctx, inp, process_group=None):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        return ReplicatedTensor(inp, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def _make_replicated_tensor(tensor, process_group):
    replicated_tensor = ReplicatedTensorFunction.apply(tensor, process_group)
    replicated_tensor.grad = tensor.grad
    return replicated_tensor


def _replicate_module_recurse(module, process_group):
    replica = module._replicate_for_data_parallel()
    for param_name, param in module._parameters.items():
        if param is not None:
            setattr(replica, param_name, _make_replicated_tensor(param, process_group))
        else:
            setattr(replica, param_name, param)
    for buffer_name, buffer in module._buffers.items():
        setattr(replica, buffer_name, buffer)
    for module_name, child in module._modules.items():
        setattr(replica, module_name, _replicate_module_recurse(child, process_group))
    return replica


def _replicate_module(network, process_group):
    from torch.nn.parallel.replicate import _replicatable_module  # type: ignore[attr-defined]
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")
    return _replicate_module_recurse(network, process_group)
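

# ---------------------------------------------------------------------------
# Usage sketch (an assumption for illustration, not part of the original
# module): _replicate_module is expected to be called with a model and an
# already-initialized process group, e.g. from data-parallel internals.
# A minimal illustration might look like:
#
#   import torch.distributed as dist
#   import torch.nn as nn
#
#   # assumes the default process group was initialized elsewhere, e.g.:
#   # dist.init_process_group("gloo", rank=rank, world_size=world_size)
#   model = nn.Linear(4, 4)
#   replica = _replicate_module(model, dist.group.WORLD)
#   # replica's parameters are ReplicatedTensors whose .grad fields alias
#   # those of the original model's parameters.
# ---------------------------------------------------------------------------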