from typing import Optional, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch._decomp import register_decomposition
from torch._decomp.decompositions import Reduction
from torch._prims_common import (
    check,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)
from torch._refs import (
    _make_elementwise_binary_reference,
    _make_elementwise_unary_reference,
)

__all__ = [
    "celu",
    "dropout",
    "elu",
    "hardshrink",
    "hardtanh",
    "hinge_embedding_loss",
    "huber_loss",
    "l1_loss",
    "margin_ranking_loss",
    "mish",
    "mse_loss",
    "poisson_nll_loss",
    "prelu",
    "relu",
    "relu6",
    "selu",
    "softplus",
    "softshrink",
    "tanhshrink",
    "threshold",
    "glu",
    "pairwise_distance",
    "pdist",
]

Tensor = torch.Tensor


# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(torch.ops.aten.celu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                type(alpha), python_type
            )
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]
    else:
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)
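
# Editor's sketch of the celu reference above (illustrative only, not part of the
# original module); it assumes eager-mode float tensors:
#
#     x = torch.tensor([-2.0, 0.0, 3.0])
#     y = celu(x, alpha=0.5)  # alpha * expm1(x / alpha) where x <= 0, x elsewhere
#     torch.testing.assert_close(y, torch.nn.functional.celu(x, alpha=0.5))
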
# TODO: should we allow the user to set a different dtype for the mask generation?
@register_decomposition(torch.ops.aten.dropout)
def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:

    if inplace:
        raise NotImplementedError

    if not training:
        return a

    assert p <= 1
    assert p >= 0

    if p == 1:
        return refs.zeros_like(a)

    if p == 0:
        return a

    p1m = 1 - p
    scale = 1 / p1m
    mask = refs.lt(
        refs.uniform(a.shape, low=0.0, high=1.0, dtype=torch.float32, device=a.device),
        p1m,
    )
    return refs.mul(refs.mul(a, mask), scale)
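
# Editor's sketch of the dropout reference above (illustrative only, not part of
# the original module). It uses "inverted" dropout: surviving elements are scaled
# by 1 / (1 - p), so the output mean matches the input mean in expectation:
#
#     x = torch.ones(10_000)
#     y = dropout(x, p=0.3, training=True)
#     y.mean()                  # close to 1.0 in expectation
#     (y == 0).float().mean()   # close to 0.3 in expectation
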
# elu is implemented specially because it has an alpha argument
# This cannot be used as a decomposition because the aten op takes in 2 extra kwargs
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                type(alpha), python_type
            )
            raise ValueError(msg)
        rhs = alpha * torch.expm1(a)
    else:
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)


@register_decomposition(torch.ops.aten.relu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu
    """

    if inplace:
        raise NotImplementedError

    return torch.where(torch.le(a, 0), 0, a)


def layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.layer_norm`.
    """
    return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0]


@register_decomposition(torch.ops.aten.leaky_relu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """

    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)
    return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope))


@register_decomposition(torch.ops.aten.mish)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish
    """

    if inplace:
        raise NotImplementedError
    return a * torch.tanh(torch.nn.functional.softplus(a))


@register_decomposition(torch.ops.aten.selu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """
    if inplace:
        raise NotImplementedError

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    rhs = alpha * torch.expm1(a)

    return scale * torch.where(a > 0, a, rhs)


# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(torch.ops.aten.softplus)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = "beta argument of type {0} cannot be safely cast to type {1}!".format(
                type(beta), python_type
            )
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]
    else:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    return torch.where(scaled_input > threshold, a, rhs)
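
# Editor's sketch of the softplus reference above (illustrative only, not part of
# the original module): beta rescales the curve, and elements with a * beta >
# threshold pass through unchanged to avoid overflow in exp():
#
#     x = torch.tensor([-1.0, 0.0, 25.0])
#     y = softplus(x, beta=1.0, threshold=20)
#     # y[:2] == log1p(exp(x[:2])); y[2] == x[2] because 25 > threshold
#     torch.testing.assert_close(y, torch.nn.functional.softplus(x, beta=1.0, threshold=20))
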
@register_decomposition(torch.ops.aten.hardshrink)
@out_wrapper()
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return refs.where(refs.logical_and(a >= -lambd, a <= lambd), 0, a)


@register_decomposition(torch.ops.aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    ge_mask = a > lambd
    le_mask = a < -lambd
    zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask))
    result = refs.where(ge_mask, a - lambd, a)
    result = refs.where(le_mask, a + lambd, result)
    return refs.where(zero_mask, 0, result)


# Losses
def _reduction_int_to_str(reduction: int) -> str:
    if reduction == Reduction.NONE.value:
        return "none"
    elif reduction == Reduction.MEAN.value:
        return "mean"
    elif reduction == Reduction.SUM.value:
        return "sum"
    else:
        raise ValueError(f"{reduction} is not a valid value for reduction")


def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType:
    if reduction == "sum":
        return refs.sum(loss)
    elif reduction == "mean":
        return refs.mean(loss)
    else:  # reduction == "none"
        return loss


def _check_reduction_value(reduction: str):
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")


# This helper function maps the deprecated arguments, "size_average" and "reduce",
# to their corresponding "reduction" string argument
def _get_string_reduction_arg(
    *, size_average: Optional[bool], reduce: Optional[bool]
) -> str:
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True
    if size_average and reduce:
        ret = "mean"
    elif reduce:
        ret = "sum"
    else:
        ret = "none"
    return ret
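
# Quick reference for _get_string_reduction_arg above (editor's note, not part of
# the original module); None behaves like True for both flags:
#
#     _get_string_reduction_arg(size_average=None, reduce=None)    # -> "mean"
#     _get_string_reduction_arg(size_average=False, reduce=True)   # -> "sum"
#     _get_string_reduction_arg(size_average=True, reduce=False)   # -> "none"
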
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.abs(input - target)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(torch.ops.aten.margin_ranking_loss)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # Formula of loss (implementation gets confusing with all the refs.foo)
    # loss_without_reduction = max(0, -target * (input1 - input2) + margin)
    if input1.ndim != input2.ndim or input1.ndim != target.ndim:
        raise RuntimeError(
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            "input1: {}, input2: {}, target: {} ".format(
                input1.shape, input2.shape, target.shape
            )
        )
    _check_reduction_value(reduction)
    neg_target = refs.neg(target)
    input_diff = refs.sub(input1, input2)
    mul_target_input = refs.mul(neg_target, input_diff)
    add_margin = refs.add(mul_target_input, margin)
    loss = refs.maximum(add_margin, 0)
    return _apply_loss_reduction(loss, reduction)


def mse_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.pow(input - target, 2)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(torch.ops.aten.hinge_embedding_loss)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # Formula of loss (implementation gets confusing with all the refs.foo)
    # loss_without_reduction = input if y == 1
    #                        = max(0, margin - input) if y == -1
    _check_reduction_value(reduction)
    margin_clamp = refs.maximum(refs.sub(margin, input), 0)
    output_margin = refs.where(refs.ne(target, 1), margin_clamp, 0)
    output_self = refs.where(refs.ne(target, -1), input, 0)
    loss = refs.add(output_margin, output_self)
    return _apply_loss_reduction(loss, reduction)


# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
# https://github.com/pytorch/pytorch/pull/85041
@register_decomposition(torch.ops.aten.huber_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def huber_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    reduction: Union[str, int] = "mean",
    delta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.huber_loss
    """
    if type(reduction) is int:
        reduction = _reduction_int_to_str(reduction)
    _check_reduction_value(reduction)  # type: ignore[arg-type]
    check(
        delta > 0,
        lambda: "huber_loss does not support non-positive values for delta.",
    )
    z = (input - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return _apply_loss_reduction(loss, reduction)  # type: ignore[arg-type]
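
# Worked example for the huber_loss reference above (editor's sketch, not part of
# the original module): the loss is quadratic while |input - target| < delta and
# linear past that point:
#
#     input = torch.tensor([0.0, 0.0])
#     target = torch.tensor([0.5, 3.0])
#     huber_loss(input, target, reduction="none", delta=1.0)
#     # -> [0.5 * 0.5 ** 2, 1.0 * (3.0 - 0.5)] == [0.1250, 2.5000]
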
# tanhshrink does not use _make_elementwise_unary_reference because it does not
# support out
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    return refs.sub(a, refs.tanh(a))


@register_decomposition(torch.ops.aten.threshold)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
    a: TensorLikeType,
    threshold: NumberType,
    value: Union[bool, int, float],
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.threshold
    """

    if inplace:
        raise NotImplementedError

    return torch.where(a <= threshold, value, a)


@register_decomposition(torch.ops.aten.hardtanh)
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def hardtanh(
    a: TensorLikeType,
    min_val: NumberType = -1,
    max_val: NumberType = 1,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hardtanh
    """
    if inplace:
        raise NotImplementedError
    if utils.is_boolean_dtype(a.dtype):
        raise RuntimeError("Bool inputs not supported for hardtanh")
    # preserve legacy behavior of boundaries not causing type promotion
    if utils.is_integer_dtype(a.dtype):
        min_val = int(min_val)  # type: ignore[arg-type]
        max_val = int(max_val)  # type: ignore[arg-type]
        if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
            raise RuntimeError(
                "Cannot do hardtanh on an unsigned type with negative limits"
            )
    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]


@register_decomposition(torch.ops.aten.gelu)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.gelu
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        a_cube = a * a * a
        inner = kBeta * (a + kKappa * a_cube)
        return 0.5 * a * (1 + torch.tanh(inner))
    elif approximate == "none":
        kAlpha = M_SQRT1_2
        return a * 0.5 * (1 + torch.erf(a * kAlpha))
    else:
        raise RuntimeError("approximate argument must be either none or tanh.")
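
# Editor's sketch comparing the two gelu branches above (illustrative only, not
# part of the original module): "tanh" is a cheaper approximation of the exact
# erf-based formula and stays close to it across typical activation ranges:
#
#     x = torch.linspace(-5, 5, steps=101)
#     exact = gelu(x, approximate="none")
#     approx = gelu(x, approximate="tanh")
#     (exact - approx).abs().max()  # small; the two curves nearly coincide
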
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def poisson_nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.poisson_nll_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    if log_input:
        loss = torch.exp(input) - target * input
    else:
        loss = input - target * torch.log(input + eps)

    if full:
        stirling_term = (
            target * torch.log(target)
            - target
            + 0.5 * torch.log(2 * torch.pi * target)
        )
        # avoid inplace add
        loss = loss + stirling_term.masked_fill(target <= 1, 0)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(torch.ops.aten.prelu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "weight"),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )
    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )
    return refs.where(a > 0, a, a * weight)


@register_decomposition(torch.ops.aten.relu6)
def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu6
    """
    if inplace:
        raise NotImplementedError

    # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
    # It may be better to use clamp here, but we use hardtanh to replicate
    # the behavior of the existing implementation
    return refs.nn.functional.hardtanh(a, 0, 6)


@register_decomposition(torch.ops.aten.glu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@out_wrapper()
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)
    return b * torch.sigmoid(c)
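
# Editor's sketch of the glu reference above (illustrative only, not part of the
# original module): the input is halved along `dim` and the second half gates the
# first through a sigmoid:
#
#     a = torch.randn(4, 6)
#     b, c = torch.tensor_split(a, 2, dim=-1)  # two (4, 3) halves
#     torch.testing.assert_close(glu(a, dim=-1), b * torch.sigmoid(c))
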
@register_decomposition(torch.ops.aten.pairwise_distance)
@out_wrapper()
def pairwise_distance(
    x1: TensorLikeType,
    x2: TensorLikeType,
    p: NumberType = 2.0,
    eps: NumberType = 1e-6,
    keepdim=False,
) -> TensorLikeType:
    return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim)


@register_decomposition(torch.ops.aten.pdist)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
@out_wrapper()
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])
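
# Editor's sanity-check sketch for the pdist reference above (illustrative only,
# not part of the original module): for p == 2 the Gram-matrix shortcut
# ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j>
# should agree with torch.nn.functional.pdist up to floating-point error:
#
#     a = torch.randn(5, 3, dtype=torch.float64)
#     torch.testing.assert_close(pdist(a, p=2), torch.nn.functional.pdist(a, p=2))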