import torch.distributed._shard.sharded_tensor._ops.chunk import torch.distributed._shard.sharded_tensor._ops.elementwise_ops import torch.distributed._shard.sharded_tensor._ops.math_ops import torch.distributed._shard.sharded_tensor._ops.matrix_ops import torch.distributed._shard.sharded_tensor._ops.tensor_ops import torch.distributed._shard.sharded_tensor._ops.misc_ops from .binary_cmp import equal, allclose from .init import kaiming_uniform_, normal_, uniform_, constant_ # Import all ChunkShardingSpec ops from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.softmax import sharded_softmax import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops