import torch.nn as nn


def batch_norm_without_running_stats(module: nn.Module) -> None:
    """Strip running statistics from *module* if it is a tracking batch norm.

    Modules that are not instances of ``nn.modules.batchnorm._BatchNorm``, or
    whose ``track_running_stats`` is already falsy, are left untouched.
    Otherwise ``running_mean``, ``running_var`` and ``num_batches_tracked``
    are set to ``None`` and ``track_running_stats`` is set to ``False``.
    The change is made in place; nothing is returned.
    """
    if not isinstance(module, nn.modules.batchnorm._BatchNorm):
        return
    if not module.track_running_stats:
        return

    module.running_mean = None
    module.running_var = None
    module.num_batches_tracked = None
    module.track_running_stats = False


def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
    """
    In place updates :attr:`root` by setting the ``running_mean``, ``running_var``
    and ``num_batches_tracked`` to be None and setting ``track_running_stats`` to
    be False for any nn.BatchNorm module in :attr:`root` (including :attr:`root`
    itself).

    Returns :attr:`root`, the same object that was passed in.
    """
    # ``nn.Module.modules()`` yields ``root`` first and then every descendant,
    # so a separate call for ``root`` is not needed.
    for submodule in root.modules():
        batch_norm_without_running_stats(submodule)
    return root