import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor


class SimpleMegatronLM(nn.Module):
    """A two-layer MLP (Linear -> GELU -> Linear) used as a minimal
    Megatron-style model whose weights may be sharded across ranks."""

    def __init__(self, linear_size, rank=None, dtype=torch.float32):
        super().__init__()
        # `linear_size` is a pair of (in_features, out_features) tuples,
        # one per linear layer.
        self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
        if rank is not None:
            # Place both linear layers on the GPU for this rank.
            self.fc1.cuda(rank)
            self.fc2.cuda(rank)

    def forward(self, inp):
        return self.fc2(self.gelu(self.fc1(inp)))

    def get_weights(self):
        # If a weight has been converted to a ShardedTensor, return only
        # the shard local to this rank; otherwise return the dense tensor.
        if isinstance(self.fc1.weight, ShardedTensor):
            weight1 = self.fc1.weight.local_tensor()
        else:
            weight1 = self.fc1.weight

        if isinstance(self.fc2.weight, ShardedTensor):
            weight2 = self.fc2.weight.local_tensor()
        else:
            weight2 = self.fc2.weight

        return (weight1, weight2)

    def get_biases(self):
        return (self.fc1.bias, self.fc2.bias)

    def get_weight_grads(self):
        return (self.fc1.weight.grad, self.fc2.weight.grad)

    def get_bias_grads(self):
        return (self.fc1.bias.grad, self.fc2.bias.grad)
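

# Illustrative usage sketch (not part of the original module): the layer
# sizes below are assumed for demonstration, and the model stays on CPU
# by leaving `rank=None`, so no distributed setup is required.
if __name__ == "__main__":
    # Two linear layers: 16 -> 32 followed by 32 -> 8.
    model = SimpleMegatronLM(linear_size=[(16, 32), (32, 8)])

    inp = torch.randn(4, 16)
    out = model(inp)
    out.sum().backward()  # populate .grad so get_*_grads() return tensors

    print(out.shape)                      # torch.Size([4, 8])
    print(model.get_weights()[0].shape)   # torch.Size([32, 16])
    print(model.get_weight_grads()[1].shape)  # torch.Size([8, 32])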