import torch
import torch.nn as nn
from torch.ao.sparsity.sparsifier.utils import module_to_fqn, fqn_to_module
from typing import Dict, List, Optional

SUPPORTED_MODULES = {
    nn.Embedding,
    nn.EmbeddingBag
}


def _fetch_all_embeddings(model):
    """Fetches Embedding and EmbeddingBag modules from the model
    """
    embedding_modules = []
    # iterative DFS over the module tree; collect supported embedding types,
    # descend into everything else
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules


def post_training_sparse_quantize(model,
                                  data_sparsifier_class,
                                  sparsify_first=True,
                                  select_embeddings: Optional[List[nn.Module]] = None,
                                  **sparse_config):
    """Takes in a model and applies sparsification and quantization to only embeddings & embedding bags.
    The quantization step can happen before or after sparsification
    depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to the model
        - sparsify_first (bool)
            if true, sparsifies first and then quantizes
            otherwise, quantizes first and then sparsifies.
        - select_embeddings (List of Embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of the data sparsifier object.

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized.
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, perform it on all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)
    else:
        embedding_modules = []
        assert isinstance(select_embeddings, list), "select_embeddings must be a list of embedding modules"
        for emb in select_embeddings:
            assert type(emb) in SUPPORTED_MODULES, "the modules in select_embeddings must be Embedding or EmbeddingBag"
            fqn_name = module_to_fqn(model, emb)
            assert fqn_name is not None, "the embedding modules must be part of the input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace('.', '_')
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)
    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scales & zero_points and dequantize so the sparsifier can
        # operate on float weights
        quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
                                            'dequant_weights': {}, 'axis': {},
                                            'dtype': {}}
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
            quantize_params['dtype'][name] = quantized_weight.dtype

            # attach data to sparsifier
            data_sparsifier.add_data(name=name.replace('.', '_'),
                                     data=quantize_params['dequant_weights'][name])

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # requantize the sparsified weights with the saved quantization params
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
                                                            scales=quantize_params['scales'][name],
                                                            zero_points=quantize_params['zero_points'][name],
                                                            dtype=quantize_params['dtype'][name],
                                                            axis=quantize_params['axis'][name])

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
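

# Usage sketch (illustrative, not part of the utility itself). It assumes
# `DataNormSparsifier` is importable from
# torch.ao.sparsity._experimental.data_sparsifier as one concrete
# `data_sparsifier_class`; `_DemoModel` is a hypothetical module invented for
# this example. Guarded so it only runs when the file is executed directly.
if __name__ == "__main__":
    from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier

    class _DemoModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(num_embeddings=100, embedding_dim=16)
            self.emb_bag = nn.EmbeddingBag(num_embeddings=100, embedding_dim=16)

    demo_model = _DemoModel()

    # Sparsify the embedding weights first, then quantize them. Keyword
    # arguments after `sparsify_first` (here, `sparsity_level`) are forwarded
    # to the DataNormSparsifier constructor via **sparse_config.
    post_training_sparse_quantize(demo_model,
                                  data_sparsifier_class=DataNormSparsifier,
                                  sparsify_first=True,
                                  sparsity_level=0.8)

    # Both embeddings should now be their quantized counterparts holding
    # sparsified weights.
    print(type(demo_model.emb), type(demo_model.emb_bag))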