# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a Trax Reformer checkpoint to a PyTorch checkpoint."""


import argparse
import logging
import pickle

import numpy as np
import torch

from transformers import ReformerConfig, ReformerModelWithLMHead


logging.basicConfig(level=logging.INFO)


def set_param(torch_layer, weight, bias=None):
    # set the weight (and optionally the bias) of a single layer
    assert torch_layer.weight.shape == weight.shape, "{} layer.weight does not match".format(torch_layer)
    torch_layer.weight = torch.nn.Parameter(weight)
    if bias is not None:
        assert torch_layer.bias.shape == bias.shape, "{} layer.bias does not match".format(torch_layer)
        torch_layer.bias = torch.nn.Parameter(bias)


def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
    # set torch weights for 1-to-1 comparison
    # LSH attention uses a shared query-key projection: (query_key, value, output dense)
    np_query_key = np.asarray(weights[0])
    np_value = np.asarray(weights[1])
    np_dense = np.asarray(weights[2])

    set_param(
        torch_layer.self_attention.query_key,
        torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    set_param(
        torch_layer.self_attention.value,
        torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    set_param(
        torch_layer.output.dense,
        torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
    )


def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size):
    # set torch weights for 1-to-1 comparison
    # local attention uses separate projections: (query, key, value, output dense)
    np_query = np.asarray(weights[0])
    np_key = np.asarray(weights[1])
    np_value = np.asarray(weights[2])
    np_dense = np.asarray(weights[3])

    set_param(
        torch_layer.self_attention.query,
        torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    set_param(
        torch_layer.self_attention.key,
        torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    set_param(
        torch_layer.self_attention.value,
        torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    set_param(
        torch_layer.output.dense,
        torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
    )


def set_block_weights_in_torch(weights, torch_block, hidden_size):
    # layernorm 1
    layer_norm_1 = weights[0][0][0]
    layer_norm_1_weight = np.asarray(layer_norm_1[0])
    layer_norm_1_bias = np.asarray(layer_norm_1[1])
    set_param(
        torch_block.attention.layer_norm,
        torch.tensor(layer_norm_1_weight),
        torch.tensor(layer_norm_1_bias),
    )

    # lsh / local attention weights + output: 3 arrays means shared query-key (LSH), 4 means local attention
    attn_weights = weights[0][1]
    if len(attn_weights) < 4:
        set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size)
    else:
        set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size)

    # intermediate weights
    intermediate_weights = weights[2][0][1][2]

    # Chunked Feed Forward
    if len(intermediate_weights) == 4:
        intermediate_weights = intermediate_weights[2]

    # layernorm 2
    layer_norm_2_weight = np.asarray(intermediate_weights[0][0])
    layer_norm_2_bias = np.asarray(intermediate_weights[0][1])
    set_param(
        torch_block.feed_forward.layer_norm,
        torch.tensor(layer_norm_2_weight),
        torch.tensor(layer_norm_2_bias),
    )

    # intermediate dense
    inter_dense_weight = np.asarray(intermediate_weights[1][0])
    inter_dense_bias = np.asarray(intermediate_weights[1][1])
    set_param(
        torch_block.feed_forward.dense.dense,
        torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(),
        torch.tensor(inter_dense_bias),
    )

    # intermediate out
    out_dense_weight = np.asarray(intermediate_weights[4][0])
    out_dense_bias = np.asarray(intermediate_weights[4][1])
    set_param(
        torch_block.feed_forward.output.dense,
        torch.tensor(out_dense_weight).transpose(0, 1).contiguous(),
        torch.tensor(out_dense_bias),
    )


def set_model_weights_in_torch(weights, torch_model, hidden_size):
    # reformer model
    torch_model_reformer = torch_model.reformer

    # word embeddings
    word_embeddings = np.asarray(weights[1])
    set_param(
        torch_model_reformer.embeddings.word_embeddings,
        torch.tensor(word_embeddings),
    )

    # axial position embeddings are stored as a tuple of factorized weight matrices
    if isinstance(weights[3], tuple):
        position_embeddings = torch_model_reformer.embeddings.position_embeddings
        for emb_idx in range(len(position_embeddings.weights)):
            emb_weights = np.asarray(weights[3][emb_idx][0])
            assert position_embeddings.weights[emb_idx].shape == emb_weights.shape, "{} emb does not match".format(
                position_embeddings.weights[emb_idx]
            )
            position_embeddings.weights[emb_idx] = torch.nn.Parameter(torch.tensor(emb_weights))

    # trax stores 4 consecutive weight groups per Reformer block
    trax_layer_weights = weights[5]
    assert len(torch_model_reformer.encoder.layers) * 4 == len(
        trax_layer_weights
    ), "HF and trax model do not have the same number of layers"
    for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers):
        block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)]
        set_block_weights_in_torch(block_weights, layer, hidden_size)

    # output layer norm
    layer_norm_out_weight = np.asarray(weights[7][0])
    layer_norm_out_bias = np.asarray(weights[7][1])
    set_param(
        torch_model_reformer.encoder.layer_norm,
        torch.tensor(layer_norm_out_weight),
        torch.tensor(layer_norm_out_bias),
    )

    # output embeddings
    output_embed_weights = np.asarray(weights[9][0])
    output_embed_bias = np.asarray(weights[9][1])
    set_param(
        torch_model.lm_head.decoder,
        torch.tensor(output_embed_weights).transpose(0, 1).contiguous(),
        torch.tensor(output_embed_bias),
    )


def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = ReformerConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = ReformerModelWithLMHead(config)

    with open(trax_model_pkl_path, "rb") as f:
        model_weights = pickle.load(f)["weights"]

    set_model_weights_in_torch(model_weights, model, config.hidden_size)

    # Save PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the Trax model pickle file."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained Reformer model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path)
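

# Usage sketch (illustrative only; the script and file names below are
# placeholders, not paths shipped with this repository):
#
#   python convert_reformer_trax_checkpoint_to_pytorch.py \
#       --trax_model_pkl_path ./model.pkl \
#       --config_file ./config.json \
#       --pytorch_dump_path ./pytorch_model.bin
#
# The saved state dict can then be restored into a freshly initialised model,
# e.g. (paths again placeholders):
#
#   config = ReformerConfig.from_json_file("./config.json")
#   model = ReformerModelWithLMHead(config)
#   model.load_state_dict(torch.load("./pytorch_model.bin"))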