from typing import List, Optional

import torch
import torch.nn.functional as F

from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
    standard_kwargs, sum_over_all_but_batch_and_last_n, unpack_expanded_weight_or_tensor


@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs)
        input = expanded_args[0]
        normalized_shape = expanded_args[1]
        if len(input.shape) <= len(normalized_shape):
            raise RuntimeError("Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient "
                               f"computations but got that normalized shape, {normalized_shape}, matched input shape.")
        output, mean, rstd = forward_helper(torch.native_layer_norm, expanded_args, expanded_kwargs)
        ctx.args = expanded_args

        # Stash weight/bias only if backward will need them, either for the input
        # gradient or for their own per-sample gradients.
        if input.requires_grad or isinstance(expanded_kwargs['weight'], ExpandedWeight):
            ctx.weight = expanded_kwargs['weight']
        if input.requires_grad or isinstance(expanded_kwargs['bias'], ExpandedWeight):
            ctx.bias = expanded_kwargs['bias']
        ctx.eps = expanded_kwargs['eps']
        ctx.mean, ctx.rstd = mean, rstd
        return output

    @staticmethod
    def backward(ctx, grad_output):
        def weight_per_sample_grad(weight):
            # Per-sample weight gradient: normalized input times grad_output, summed
            # over everything except the batch dimension and the weight's own dims.
            return sum_over_all_but_batch_and_last_n(
                F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output, weight.dim())

        input, normalized_shape = ctx.args
        mean, rstd = ctx.mean, ctx.rstd

        # One gradient slot per forward argument: kwarg names, op reference, then the expanded args/kwargs.
        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
        if input.requires_grad:
            weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
            bias_ = unpack_expanded_weight_or_tensor(ctx.bias)
            # Only the input gradient is requested from native_layer_norm_backward here.
            results.append(torch.ops.aten.native_layer_norm_backward(
                grad_output, input, normalized_shape, mean, rstd, weight_, bias_, (True, False, False))[0])
        else:
            results.append(None)

        # weight and bias don't compute batched gradients; no other arguments are differentiable
        results = results + [None] * 4

        # set grad_sample field for weight and bias with per sample gradients
        if hasattr(ctx, "weight"):
            set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad)
        if hasattr(ctx, "bias"):
            set_grad_sample_if_exists(ctx.bias,
                                      lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim()))
        return tuple(results)
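

if __name__ == "__main__":
    # Standalone sketch (not part of the ExpandedWeight machinery above): check the
    # closed-form per-sample gradients used in backward() -- normalized input *
    # grad_output for the weight, grad_output for the bias, summed over all but the
    # batch and parameter dimensions -- against a naive one-sample-at-a-time autograd
    # loop. Uses only public torch APIs; all names below are local to this example.
    torch.manual_seed(0)
    batch, features = 4, 8
    shape = (features,)
    eps = 1e-5

    x = torch.randn(batch, features)
    weight = torch.randn(features, requires_grad=True)
    bias = torch.randn(features, requires_grad=True)
    grad_out = torch.randn(batch, features)  # stands in for an upstream gradient

    # Closed form, mirroring weight_per_sample_grad and the bias lambda above.
    normalized = F.layer_norm(x, shape, eps=eps)  # no affine params -> just the normalized input
    per_sample_weight_grad = normalized * grad_out  # already (batch, features); nothing left to sum
    per_sample_bias_grad = grad_out

    # Reference: one autograd call per sample.
    for i in range(batch):
        out_i = F.layer_norm(x[i:i + 1], shape, weight, bias, eps)
        gw, gb = torch.autograd.grad(out_i, (weight, bias), grad_outputs=grad_out[i:i + 1])
        assert torch.allclose(gw, per_sample_weight_grad[i], atol=1e-5)
        assert torch.allclose(gb, per_sample_bias_grad[i], atol=1e-5)
    print("per-sample LayerNorm gradients match the naive loop")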