import functools
from typing import Optional

import torch
import torch.distributed as dist


class DefaultState:
    r"""
    Stores state needed to perform the default communication algorithm
    within a communication hook.

    Args:
        process_group (ProcessGroup): The process group to be used.
    """

    __slots__ = [
        "process_group",
        "world_size",
        "gradient_predivide_factor",
        "gradient_postdivide_factor",
    ]

    def __init__(
        self,
        process_group: dist.ProcessGroup,
    ):
        if process_group is None:
            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        # Set two factors, `self.gradient_predivide_factor` and
        # `self.gradient_postdivide_factor`, to avoid underflow and overflow.
        self.gradient_predivide_factor = self._get_gradient_predivide_factor(
            self.world_size
        )
        self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor

    def _get_gradient_predivide_factor(self, world_size: int) -> float:
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)


class LowPrecisionState(DefaultState):
    r"""
    Stores state needed to perform gradient communication in a lower precision
    within a communication hook. The hook casts gradients back to the original
    parameter precision specified by ``parameter_type`` (default: torch.float32).
    Builds on top of the :class:`DefaultState`.

    Args:
        parameter_type (torch.dtype): The precision of the model's parameters.
            Required for a hook to cast gradients back to a parameter's precision.
    """

    __slots__ = [
        "parameter_type",
    ]

    def __init__(
        self,
        process_group,
        parameter_type=torch.float32,
    ):
        super().__init__(process_group)
        self.parameter_type = parameter_type


def _decompress(state: LowPrecisionState, grad: torch.Tensor):
    """
    Casts gradients back to full parameter precision so that
    further computation happens in full precision.
    """
    orig_grad_data = grad.data
    grad.data = grad.data.to(state.parameter_type)
    # Don't let this memory get reused until after the transfer.
    orig_grad_data.record_stream(torch.cuda.current_stream())  # type: ignore[arg-type]


def allreduce_hook(state: DefaultState, grad: torch.Tensor):
    r"""
    This FSDP communication hook implements the ``all_reduce`` algorithm
    and a necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): A gradient for the local batch that needs to be
            communicated across ranks.
    """
    # Average grad by pre-division factor. Together, pre- and post-division factors
    # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
    # This is a two-step process to avoid potential underflow and overflow.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist.all_reduce(grad, group=state.process_group)
    # Average grad by post-division factor.
    if state.gradient_postdivide_factor > 1:
        grad.div_(state.gradient_postdivide_factor)
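
# The sketch below is illustrative only and is not part of the upstream module:
# ``_demo_division_factors`` is a hypothetical helper showing how the pre- and
# post-division factors chosen by ``DefaultState`` combine to an overall division
# by ``world_size``. Splitting the division across the two sides of the collective
# keeps low-precision gradients in a safer numeric range than dividing by the full
# ``world_size`` at once.
def _demo_division_factors(world_size: int = 64) -> None:
    # The helper does not use ``self``, so it can be called unbound for illustration.
    pre = DefaultState._get_gradient_predivide_factor(None, world_size)  # e.g. 8.0 for 64 ranks
    post = world_size / pre                                              # e.g. 8.0 for 64 ranks
    # Dividing by ``pre`` before the collective and by ``post`` after it is
    # equivalent to a single division by ``world_size``.
    assert pre * post == world_size
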
""" # Average grad by pre-division factor. if state.gradient_predivide_factor > 1: grad.div_(state.gradient_predivide_factor) dist._reduce_scatter_base( output, grad, group=state.process_group ) # Average grad's shard by post-division factor. if state.gradient_postdivide_factor > 1: output.div_(state.gradient_postdivide_factor) def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor): grad.data = grad.data.to(prec) if output is not None: output.data = output.data.to(prec) reduce_scatter_hook(state, grad, output) _decompress(state, output) else: allreduce_hook(state, grad) _decompress(state, grad) def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): r""" This FSDP communication hook implements a simple gradient compression approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) gradients are averaged by a ``state.gradient_postdivide_factor``. Once post-division is done, compressed gradients are casted back to parameters' precision. Args: state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ fp16_hook = functools.partial(_low_precision_hook, torch.float16) return fp16_hook(state, grad, output) def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): r""" This FSDP communication hook implements a simple gradient compression approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) gradients are averaged by a ``state.gradient_postdivide_factor``. Once post-division is done, compressed gradients are casted back to parameters' precision. Args: state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. """ bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) return bf16_hook(state, grad, output)