import torch
import torch.fx
import traceback

from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility

__all__ = ['TensorMetadata', 'ShapeProp']

@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    # General Tensor metadata
    shape : torch.Size
    dtype : torch.dtype
    requires_grad : bool
    stride : Tuple[int]
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool
    qparams: Dict[str, Any]

def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_formats = {
        torch.contiguous_format,
        torch.channels_last,
        torch.channels_last_3d,
    }

    memory_format = None

    for query_format in memory_formats:
        if result.is_contiguous(memory_format=query_format):
            memory_format = query_format
            break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are expected to be tensors;
            # we store the values as lists in TensorMetadata for
            # easier serialization downstream
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
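
# Illustrative only (not part of the upstream API): a minimal sketch of what
# _extract_tensor_metadata records for an ordinary, non-quantized tensor. The
# function name and the concrete tensor are assumptions made for this example.
def _example_extract_metadata() -> TensorMetadata:
    t = torch.randn(2, 3)                      # contiguous fp32 tensor
    meta = _extract_tensor_metadata(t)
    assert meta.shape == torch.Size([2, 3])
    assert meta.dtype == torch.float32
    assert meta.stride == (3, 1)               # row-major strides for a 2x3 tensor
    assert meta.memory_format == torch.contiguous_format
    assert meta.is_quantized is False          # so meta.qparams stays {}
    return meta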

@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
        In this example, we record the shape
        and data type of a module given
        an example input ``torch.randn(50, D_in)``.
        We print the name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super(TwoLayerNet, self).__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)

            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

    Args:
        module (GraphModule): The module to be executed
    """
    def run_node(self, n : Node) -> Any:
        try:
            result = super().run_node(n)
        except Exception as e:
            traceback.print_exc()
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation, recording the shape and type of
        each node along the way, and return the result.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
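
# A minimal end-to-end sketch, illustrative only: it mirrors the class
# docstring example on a smaller module. `_ExampleNet` and its layer sizes
# are assumptions made for this demo, not part of the module's API.
if __name__ == "__main__":
    class _ExampleNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.fc(x).clamp(min=0)

    gm = torch.fx.symbolic_trace(_ExampleNet())
    ShapeProp(gm).propagate(torch.randn(2, 8))   # annotate nodes with metadata
    for node in gm.graph.nodes:
        tm = node.meta.get('tensor_meta')
        if tm is not None:
            print(node.name, tm.dtype, tm.shape)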