import copy

import torch
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
)

from ._common import (
    _register_sharded_op_on_local_shards,
)


def sharded_type_as_check(*args, **kwargs):
    """
    Perform extra checks for the ``sharded_type_as`` op, e.g. the input
    needs to be either a Tensor or a ShardedTensor.

    Args: same as ``torch.Tensor.type_as``.

    Return: None
    """
    if len(args) < 2:
        raise ValueError("Needs to give a tensor to cast type as!")
    if not isinstance(args[1], (torch.Tensor, ShardedTensor)):
        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")


def same_dtype(*args, **kwargs):
    """
    When the dtype is the same, return the original ShardedTensor.

    Args: same as ``torch.Tensor.type_as``.

    Return (bool): Whether to return early or not.
    """
    return args[0].dtype == args[1].dtype


def sharded_type_as(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.

    Args: same as ``torch.Tensor.type_as``.

    Return:
        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
    """
    st = args[0]
    tensor = args[1]
    if isinstance(tensor, ShardedTensor):
        tensor = tensor.local_tensor()
    # Cast each local shard to the target dtype; shard metadata is unchanged.
    new_local_shards = []
    for shard in st.local_shards():
        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
    # Copy the global metadata and update only the dtype to match the target tensor.
    st_meta = copy.deepcopy(st._metadata)
    st_meta.tensor_properties.dtype = tensor.dtype
    return new_local_shards, st_meta


_register_sharded_op_on_local_shards(
    torch.Tensor.type_as,
    early_stop_func=same_dtype,
    extra_check=sharded_type_as_check,
    customized_func=sharded_type_as,
)
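

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the op registration above):
# assumes a single-rank gloo process group so it can run standalone, and uses
# the standard ShardedTensor creation APIs (``ChunkShardingSpec`` and
# ``sharded_tensor.rand``) to exercise the ``type_as`` path registered here.
# The port/address values are arbitrary placeholders for local testing.
if __name__ == "__main__":
    import os

    import torch.distributed as dist
    from torch.distributed._shard import sharded_tensor
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Build a float32 ShardedTensor chunked along dim 0 on a single rank.
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
    st = sharded_tensor.rand(spec, 4, 4)

    # Same dtype: ``same_dtype`` short-circuits, so no new tensor is built.
    same = st.type_as(torch.empty(1, dtype=torch.float32))
    assert same.dtype == torch.float32

    # Different dtype: dispatches to ``sharded_type_as``, which casts every
    # local shard and updates the global metadata's dtype.
    st64 = st.type_as(torch.empty(1, dtype=torch.float64))
    assert st64.dtype == torch.float64
    assert st64.local_shards()[0].tensor.dtype == torch.float64

    dist.destroy_process_group()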