import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from typing import List, Callable, Tuple, Any, Dict, Optional

__all__ = [
    "QuantizationTracer",
]


class Scope:
    """
    Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: qualified name of the module ("" for the top-level module).
        module_type: type of the module (None for the top-level module, see
            QuantizationTracer.__init__).
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(module_path={self.module_path!r}, "
            f"module_type={self.module_type!r})"
        )


class ScopeContextManager:
    """
    A context manager to track the Scope of Node during symbolic tracing.

    On entry, the shared ``scope`` object is updated in place to describe
    ``current_module``; on exit, the values it held at entry time are
    restored. The update happens in ``__enter__`` (not at construction),
    so creating the manager without entering it has no side effect.

    Args:
        scope: the shared, mutable Scope owned by the tracer.
        current_module: module whose ``forward`` is being traced.
        current_module_path: qualified name of ``current_module``.
    """

    def __init__(
        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
    ):
        super().__init__()
        self.scope = scope
        self.current_module_path = current_module_path
        self.current_module_type = type(current_module)
        # Populated in __enter__ so the saved values are the ones live at entry.
        self.prev_module_path: Optional[str] = None
        self.prev_module_type: Any = None

    def __enter__(self) -> Scope:
        self.prev_module_path = self.scope.module_path
        self.prev_module_type = self.scope.module_type
        self.scope.module_path = self.current_module_path
        self.scope.module_type = self.current_module_type
        return self.scope

    def __exit__(self, *args) -> None:
        # Always restore, even if tracing raised; returning None does not
        # suppress the exception.
        self.scope.module_path = self.prev_module_path
        self.scope.module_type = self.prev_module_type


class QuantizationTracer(Tracer):
    """
    FX Tracer used by quantization that (1) treats torch.nn / torch.ao.nn
    modules (except Sequential), fused modules, and user-skipped modules as
    leaves, and (2) records, for every created Node, the (module_path,
    module_type) scope that was active when the Node was created.

    Args:
        skipped_module_names: qualified module names that must not be traced into.
        skipped_module_classes: module classes (exact type match) that must not
            be traced into.
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialized the module_type of top level module to None
        # we are assuming people won't configure the model with the type of top level
        # module here, since people can use "" for global config
        # We can change this if there is a use case that configures
        # qconfig using top level module type
        self.scope = Scope("", None)
        # node name -> (module_path, module_type); module_type is None for
        # nodes created at the top level.
        self.node_name_to_scope: Dict[str, Tuple[str, Optional[type]]] = {}
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True if ``m`` should be recorded as a single call_module node."""
        is_builtin_nn_module = m.__module__.startswith(
            ("torch.nn", "torch.ao.nn")
        ) and not isinstance(m, torch.nn.Sequential)
        return (
            is_builtin_nn_module
            or module_qualified_name in self.skipped_module_names
            or type(m) in self.skipped_module_classes
            or isinstance(m, _FusedModule)
        )

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """Trace a module call with ``self.scope`` set to that module."""
        module_qualified_name = self.path_of_module(m)
        # Creating scope with information of current module
        # scope will be restored automatically upon exit
        with ScopeContextManager(self.scope, m, module_qualified_name):
            return super().call_module(m, forward, args, kwargs)

    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """Create a Node via the base Tracer and record the current scope for it."""
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        return node