import math
from typing import List, Optional, Union

import torch
import torch._prims_common as utils
from torch import Tensor
from torch._prims_common import (
    check,
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
)
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
from torch.utils._pytree import tree_map

aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")

meta_table = {}


def register_meta(op, register_dispatcher=True):
    def wrapper(f):
        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
                name = (
                    op.__name__
                    if op._overloadname != "default"
                    else op.overloadpacket.__name__
                )
                _meta_lib_dont_use_me_use_register_meta.impl(name, f)
            op.py_impl(torch._C.DispatchKey.Meta)(f)

        tree_map(add_func, op)
        return f

    return wrapper


def toRealValueType(dtype):
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    return from_complex.get(dtype, dtype)


@register_meta(aten._fft_c2c.default)
def meta_fft_c2c(self, dim, normalization, forward):
    assert self.dtype.is_complex
    return self.new_empty(self.size())


@register_meta(aten._fft_r2c.default)
def meta_fft_r2c(self, dim, normalization, onesided):
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )


@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    assert out.ndim == 1 and out.size(0) == n
    return out


@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


@register_meta(aten.copy_.default, register_dispatcher=False)
def meta_copy_(self, src, non_blocking=False):
    return self


# Implementations below are taken from
# https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)


@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.index_select(self, dim, index))


@register_meta([aten.max.default, aten.min.default])
def meta_max(self):
    return self.new_empty(())


@register_meta(aten.angle.default)
def meta_angle(self):
    if self.is_complex():
        result_dtype = corresponding_real_dtype(self.dtype)
    else:
        _, result_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    return self.new_empty(self.size(), dtype=result_dtype)


@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))

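
# Illustrative example (not in the original source) of the onesided rule in
# meta_fft_r2c above: a real input of shape (8, 16) transformed over dim=[1]
# produces a complex output of shape (8, 9), since the last transformed dim
# becomes n // 2 + 1 (here 16 // 2 + 1 = 9).
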
def squareCheckInputs(self, f_name):
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert self.size(-1) == self.size(
        -2
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"


def checkUplo(uplo: str):
    uplo_uppercase = uplo.upper()
    # NB: parenthesized so that the length check applies to both letters
    assert len(uplo) == 1 and (
        uplo_uppercase == "U" or uplo_uppercase == "L"
    ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"


# @register_meta(aten.linalg_eigh.default)
def meta_linalg_eigh(self, uplo="L"):
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    real_dtype = toRealValueType(self.dtype)
    assert self.dim() >= 2
    # Eigenvalues are real and drop the last matrix dimension; eigenvectors
    # keep the full (batched) square shape, transposed to column-major strides.
    values = self.new_empty(self.shape[:-1], dtype=real_dtype)
    vectors = self.new_empty(self.shape)
    vectors.transpose_(-2, -1)
    return (values, vectors)


@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    valid_dims = self.size(1) != 0 and self.size(2) != 0
    check(
        (self.ndim == 3 and valid_dims)
        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )
    if self.ndim == 4:
        nbatch, nplane, input_h, input_w = self.shape
    else:
        nbatch = 1
        nplane, input_h, input_w = self.shape

    pad_l, pad_r, pad_t, pad_b = padding

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    if self.ndim == 3:
        return self.new_empty((nplane, output_h, output_w))
    else:
        return self.new_empty((nbatch, nplane, output_h, output_w))


def dot_check(self, other):
    check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )


@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    dot_check(self, tensor)
    return self.new_empty(())


@register_meta([aten.mm.default], register_dispatcher=False)
def meta_mm(a, b):
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    N, M1 = a.shape
    M2, P = b.shape
    check(M1 == M2, lambda: "a and b must have same reduction dim")
    return a.new_empty(N, P)


def _compute_reduction_shape(self, dims, keepdim):
    if keepdim:
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))

    return utils.compute_reduction_output_shape(self.shape, dims)


@register_meta(aten.bernoulli.out)
def meta_bernoulli(self, *, generator=None, out):
    torch._resize_output_(out, self.size(), self.device)
    return out


@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1
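
    # Worked example for _formula (illustrative, not from the original source):
    # with ln=32, p=1, d=1, k=3, s=2 it gives (32 + 2 - 2 - 1) // 2 + 1 = 16,
    # i.e. a stride-2 3x3 conv with padding 1 halves a 32-wide dimension.
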
    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim
        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    def calc_conv_nd_return_shape(
        dims: torch.Size,
        kernel_size: torch.Size,
        stride: Union[List[int], int],
        padding: Union[List[int], int],
        dilation: Union[List[int], int],
        output_padding: Optional[Union[List[int], int]] = None,
    ):
        ret_shape = []
        if isinstance(stride, int):
            stride = [stride] * len(dims)
        elif len(stride) == 1:
            stride = [stride[0]] * len(dims)

        if isinstance(padding, int):
            padding = [padding] * len(dims)
        elif len(padding) == 1:
            padding = [padding[0]] * len(dims)

        if isinstance(dilation, int):
            dilation = [dilation] * len(dims)
        elif len(dilation) == 1:
            dilation = [dilation[0]] * len(dims)

        output_padding_list: Optional[List[int]] = None
        if output_padding:
            if isinstance(output_padding, int):
                output_padding_list = [output_padding] * len(dims)
            elif len(output_padding) == 1:
                output_padding_list = [output_padding[0]] * len(dims)
            else:
                output_padding_list = output_padding

        for i in range(len(dims)):
            # If output_padding is present, we are dealing with a transposed convolution
            if output_padding_list:
                ret_shape.append(
                    _formula_transposed(
                        dims[i],
                        padding[i],
                        dilation[i],
                        kernel_size[i],
                        stride[i],
                        output_padding_list[i],
                    )
                )
            else:
                ret_shape.append(
                    _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
                )
        return ret_shape

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format(device_hint):
        if device_hint == "cuda":
            if is_channels_last(input_tensor) or is_channels_last(weight):
                return torch.channels_last
        else:
            if is_channels_last(input_tensor):
                return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    if is_transposed:
        out_channels = groups * weight.shape[1]

        shape_out = calc_conv_nd_return_shape(
            dims,
            kernel_size,
            stride,
            padding,
            dilation,
            output_padding,
        )

    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")
        shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))

    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(input_tensor, FakeTensor):
        device_hint = input_tensor.fake_device.type
    else:
        device_hint = "cuda"  # default to cuda

    mem_fmt = pick_memory_format(device_hint)
    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
    return out

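
# A minimal usage sketch (illustrative, not part of the original file): shape
# inference for a conv runs with no real compute when the inputs live on the
# meta device, e.g.
#   x = torch.empty(2, 3, 32, 32, device="meta")
#   w = torch.empty(8, 3, 3, 3, device="meta")
#   torch.ops.aten.convolution(x, w, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
# returns a meta tensor of shape (2, 8, 16, 16).
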
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    check(
        tensor.dim() == dim and tensor.shape[dim_size] == size,
        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
        + f"but got: dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
    )


@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)
    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH, dW = stride[0], stride[0]
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )

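
# Worked example (illustrative, not in the original): for a (1, 3, 32, 32)
# input with kernel_size=[2] and the stride defaulting to the kernel size,
# meta_avg_pool2d above yields an output of shape (1, 3, 16, 16).
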
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    nOutputPlane = nInputPlane

    check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
    check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
    check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)


@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    check(
        self.ndim == 3 or self.ndim == 4,
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    return self.new_empty(self.shape[:-2] + tuple(output_size))


@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    check(
        self.ndim == 4 or self.ndim == 5,
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    return self.new_empty(self.shape[:-3] + tuple(output_size))


@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    if output_size is None:
        raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
    return repeats.new_empty(output_size)


@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    assert real.dtype.is_floating_point
    assert imag.dtype.is_floating_point
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))


@register_meta(aten.vdot.default)
def vdot(self, other):
    # NB: is_complex() must be called; the bound method itself is always truthy.
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    dot_check(self, other)
    return self.new_empty(())


# Leaving this function around because a python implementation
# of indexing shape inference is useful,
# but not registering it to the dispatcher because we already
# get shape inference through structured kernels
@register_meta(aten.index.Tensor, register_dispatcher=False)
def meta_index_Tensor(self, indices):
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            # NB: uint8 is the (deprecated) byte mask dtype accepted by eager
            # indexing, matching the error message below.
            check(
                index.dtype in [torch.long, torch.uint8, torch.bool],
                lambda: "tensors used as indices must be long, byte or bool tensors",
            )
            if index.dtype in [torch.uint8, torch.bool]:
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)
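
    # Illustrative example (not in the original): for self of shape (5, 6, 7, 8),
    # self[:, idx1, idx2] has its advanced indices adjacent (a contiguous
    # subspace), so the broadcast index shape replaces dims 1-2 in place;
    # in self[idx1, :, idx2] they are not adjacent, and the index dims are
    # moved to the front of the result, as computed below.
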
    # hasContiguousSubspace
    # true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)


@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())

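
# Illustrative shape example (not in the original): addbmm with batch1 of shape
# (b, n, m) and batch2 of shape (b, m, p) reduces over the batch dimension and
# returns an (n, p) tensor, so self must broadcast to (n, p).
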
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)


@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices

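
# Illustrative shape example (not in the original): for a (10, 4) weight, 6
# indices, and offsets = [0, 2, 5], meta_embedding_bag above returns an output
# of shape (3, 4) -- one 4-dim bag per offset.
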
@register_meta([aten.diag.default, aten.diag.out])
@out_wrapper()
def meta_diag(self, dim=0):
    check(self.dim() in (1, 2), lambda: "matrix or a vector expected")
    if self.dim() == 1:
        sz = self.size(0) + abs(dim)
        return self.new_empty((sz, sz))

    # case: dim is 2
    if dim >= 0:
        sz = min(self.size(0), self.size(1) - dim)
    else:
        sz = min(self.size(0) + dim, self.size(1))
    return self.new_empty((sz,))


@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
        weight, indices, offsets, *args
    )
    if offsets.device.type == "cpu":
        bag_size = offsets.new_empty(offsets.size())
    return output, offset2bag, bag_size, max_indices


def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
    # if specified, dtype takes precedence
    if dtype:
        return dtype

    if input.dtype.is_floating_point or input.dtype.is_complex:
        return input.dtype
    elif promote_int_to_long:
        return torch.long

    return input.dtype


@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    dims = utils.reduction_dims(input.shape, dims)
    output_shape = _compute_reduction_shape(input, dims, keepdim)
    return input.new_empty(output_shape, dtype=output_dtype)

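
# Illustrative example (not in the original): reducing a (2, 3, 4) input over
# dims (0, 2) yields shape (3,) with keepdim=False and (1, 3, 1) with
# keepdim=True, per _compute_reduction_shape above.
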
@register_meta(aten.nanmedian.default)
def meta_nanmedian(input):
    output_shape = utils.compute_reduction_output_shape(
        input.shape, tuple(range(input.dim()))
    )
    return input.new_empty(output_shape)


@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
@out_wrapper("values", "indices")
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    dim = utils.reduction_dims(input.shape, (dim,))
    output_shape = _compute_reduction_shape(input, dim, keepdim)
    return (
        input.new_empty(output_shape),
        input.new_empty(output_shape, dtype=torch.long),
    )


@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    return self


@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    check(
        len(repeats) >= self.dim(),
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # Add new leading dimensions to the tensor if the
    # number of target dimensions is larger than the
    # number of source dimensions.
    num_new_dimensions = len(repeats) - self.dim()
    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
    return self.new_empty(target_size)


@register_meta(aten.zero_.default, register_dispatcher=False)
def meta_zero_(self):
    return self


@register_meta(
    [aten.fill.Tensor, aten.fill.Scalar, aten.fill_.Tensor, aten.fill_.Scalar],
    register_dispatcher=False,
)
def meta_fill_(self, val):
    return self


@register_meta(aten.relu_.default, register_dispatcher=False)
def meta_relu_(self):
    return self


@register_meta(aten.index_put.default, register_dispatcher=False)
def meta_index_put(self, indices, values, accumulate=False):
    return self.new_empty(self.size())


@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False)
def meta_masked_fill_(self, mask, value):
    return self


@register_meta(aten.index_put_.default, register_dispatcher=False)
def meta_index_put_(self, indices, values, accumulate=False):
    return self


@register_meta(aten.alias.default, register_dispatcher=False)
def meta_alias(self):
    return self.view(self.shape)


def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    batch1_sizes = batch1.size()
    batch2_sizes = batch2.size()

    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    res_rows = batch1_sizes[1]
    res_cols = batch2_sizes[2]
    output_size = (bs, res_rows, res_cols)

    check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out
    output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
        check(
            self_baddbmm.size() == output_size,
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
        )

    return output


@register_meta(aten.bmm.default, register_dispatcher=False)
def meta_bmm(self, mat2):
    return common_meta_baddbmm_bmm(self, mat2, True)


def div_rtn(x, y):
    q = x // y
    r = x % y
    # WARNING: explicit bool conversion here is necessary;
    # would be fixed by SymBool
    if r != 0 and (bool(r < 0) != bool(y < 0)):
        q -= 1
    return q


def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            - dilation * (kernelSize - 1)
            - 1
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize -= 1
    return outputSize


def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    check(stride != 0, lambda: "stride should not be zero")
    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    check(
        pad <= kernelSize // 2,
        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
    )
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
    )

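
# Worked example (illustrative, not in the original): pooling_output_shape(32,
# 3, 1, 2, 1, False) computes div_rtn(32 + 1 + 1 - 2 - 1, 2) + 1 =
# div_rtn(31, 2) + 1 = 16, matching a 3x3, stride-2, padding-1 pooling window
# over a 32-wide input.
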
def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    ndim = input.dim()
    nOutputPlane = nInputPlane

    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )

    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )

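
# Illustrative example (not in the original): max_pool2d_with_indices on a
# (1, 3, 32, 32) input with kernel_size=[2] returns two (1, 3, 16, 16) meta
# tensors -- the pooled values and the int64 argmax indices.
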
@register_meta(aten.max_pool2d_with_indices.default, register_dispatcher=False)
def meta_max_pool2d_with_indices(
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)
    check(
        len(stride) in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)
    dilationH, dilationW = unpack("dilation", dilation)

    memory_format = utils.suggest_memory_format(input)
    if memory_format == torch.channels_last:
        check(
            input.dim() == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        check(
            input.dim() in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        check(
            False,
            lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous",
        )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return (
        torch.empty(
            size, dtype=input.dtype, device=input.device, memory_format=memory_format
        ),
        torch.empty(
            size, dtype=torch.int64, device=input.device, memory_format=memory_format
        ),
    )


@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    return torch.empty(size, *args, **kwargs)


@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.zeros_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    return aten.empty_like.default(self, **kwargs)


# hacky: Please remove after math.ceil works with arange
@register_meta(aten.arange.default)
def arange(end, **kwargs):
    if isinstance(end, float):
        end = math.ceil(end)

    def is_integral(x):
        return isinstance(x, int) or isinstance(x, bool)

    set_to_integral_dtype = kwargs.get("dtype", None) is None and is_integral(end)
    if set_to_integral_dtype:
        kwargs["dtype"] = torch.int64

    return aten.empty([end], **kwargs)


@register_meta(aten.arange.start)
def arange_start(start, end, **kwargs):
    return aten.arange(end - start, **kwargs)


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special
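
# A minimal end-to-end sketch of what these registrations enable (illustrative,
# assumes the registrations above are active): ops on "meta" tensors propagate
# shapes and dtypes without allocating or computing any real data, e.g.
#   a = torch.empty(4, 5, device="meta")
#   b = torch.empty(5, 6, device="meta")
#   (a @ b).shape  # torch.Size([4, 6]), via meta_mm above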