# coding=utf-8
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 T5 model. """


import copy
import itertools
import logging
import math

import tensorflow as tf

from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list


logger = logging.getLogger(__name__)

TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "t5-small": "https://cdn.huggingface.co/t5-small-tf_model.h5",
    "t5-base": "https://cdn.huggingface.co/t5-base-tf_model.h5",
    "t5-large": "https://cdn.huggingface.co/t5-large-tf_model.h5",
    "t5-3b": "https://cdn.huggingface.co/t5-3b-tf_model.h5",
    "t5-11b": "https://cdn.huggingface.co/t5-11b-tf_model.h5",
}

####################################################
# TF 2.0 Models are constructed using the Keras imperative API by sub-classing
# - tf.keras.layers.Layer for the layers and
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
####################################################


class TFT5LayerNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        """ Construct a layernorm module in the T5 style.
            No bias and no subtraction of mean.
""" super().__init__(**kwargs) self.variance_epsilon = epsilon def build(self, input_shape): """Build shared word embedding layer """ self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones") super().build(input_shape) def call(self, x): variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True) x = x * tf.math.rsqrt(variance + self.variance_epsilon) return self.weight * x class TFT5DenseReluDense(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi") self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) self.act = tf.keras.activations.relu def call(self, hidden_states, training=False): h = self.wi(hidden_states) h = self.act(h) h = self.dropout(h, training=training) h = self.wo(h) return h class TFT5LayerFF(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense") self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) def call(self, hidden_states, training=False): norm_x = self.layer_norm(hidden_states) y = self.DenseReluDense(norm_x, training=training) layer_output = hidden_states + self.dropout(y, training=training) return layer_output class TFT5Attention(tf.keras.layers.Layer): NEW_ID = itertools.count() def __init__(self, config, has_relative_attention_bias=False, **kwargs): super().__init__(**kwargs) self.layer_id = next(TFT5Attention.NEW_ID) self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias self.output_attentions = config.output_attentions self.relative_attention_num_buckets = config.relative_attention_num_buckets self.d_model = config.d_model self.d_kv = config.d_kv self.n_heads = config.num_heads self.inner_dim = self.n_heads * self.d_kv # Mesh TensorFlow initialization to avoid scaling before softmax self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q") self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k") self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v") self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) if self.has_relative_attention_bias: self.relative_attention_bias = tf.keras.layers.Embedding( self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias", ) self.pruned_heads = set() def prune_heads(self, heads): raise NotImplementedError @staticmethod def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
        This should allow for more graceful generalization to longer sequences
        than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
            n = tf.math.abs(n)
        else:
            n = tf.math.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = tf.math.less(n, max_exact)
        val_if_large = max_exact + tf.dtypes.cast(
            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact),
            tf.int32,
        )
        val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
        ret += tf.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen):
        """ Compute binned relative position bias """
        context_position = tf.range(qlen)[:, None]
        memory_position = tf.range(klen)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
        )
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
        return values

    def call(
        self,
        input,
        mask=None,
        kv=None,
        position_bias=None,
        cache=None,
        past_key_value_state=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        training=False,
    ):
        """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
        bs, qlen, dim = shape_list(input)

        if past_key_value_state is not None:
            assert self.is_decoder is True, "Encoder cannot cache past key value states"
            assert (
                len(past_key_value_state) == 2
            ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value_state)
            )
            real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length
        else:
            real_qlen = qlen

        if kv is None:
            klen = real_qlen
        else:
            klen = shape_list(kv)[1]

        def shape(x):
            """ projection """
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))

        def unshape(x):
            """ compute context """
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))

        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

        if kv is None:
            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
        elif past_key_value_state is None:
            k = v = kv
            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)

        if past_key_value_state is not None:
            if kv is None:
                k_, v_ = past_key_value_state
                k = tf.concat([k_, k], axis=2)  # (bs, n_heads, klen, dim_per_head)
                v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
            else:
                k, v = past_key_value_state

        # to cope with keras serialization
        # we need to cast `use_cache` to the correct bool
        # if it is a tensor
        if tf.is_tensor(use_cache):
            if hasattr(use_cache, "numpy"):
                use_cache = bool(use_cache.numpy())
            else:
                use_cache = True

        if self.is_decoder and use_cache is True:
            present_key_value_state = ((k, v),)
        else:
            present_key_value_state = (None,)

        scores = tf.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)

        if position_bias is None:
            if not self.has_relative_attention_bias:
                raise ValueError("No position_bias provided and no weights to compute position_bias")
            position_bias = self.compute_bias(real_qlen, klen)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value_state is not None:
                position_bias = position_bias[:, :, -1:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)

        scores += position_bias
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)

        context = self.o(context)

        outputs = (context,) + present_key_value_state

        if self.output_attentions:
            outputs = outputs + (weights,)
        if self.has_relative_attention_bias:
            outputs = outputs + (position_bias,)
        return outputs


class TFT5LayerSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.SelfAttention = TFT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
        )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        use_cache=False,
        training=False,
    ):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            norm_x,
            mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
            use_cache=use_cache,
            training=training,
        )
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y, training=training)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
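
# A rough illustrative sketch of the relative-position bucketing implemented in
# TFT5Attention._relative_position_bucket above, with its default arguments
# (num_buckets=32, max_distance=128); kept entirely in comments so nothing runs
# at import time:
#
#     rel = tf.constant([[0, 1, 7, 20, 300]], dtype=tf.int32)
#     buckets = TFT5Attention._relative_position_bucket(rel, bidirectional=True)
#
# In the bidirectional case the bucket range is split in half by the sign of the
# offset; within each half, offsets smaller than num_buckets // 4 each get their
# own bucket, larger offsets share logarithmically sized buckets, and offsets at
# or beyond max_distance all land in that half's last bucket. The first block
# (has_relative_attention_bias=True) computes these biases once and later blocks
# reuse them through the `position_bias` argument.
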
class TFT5LayerCrossAttention(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.EncDecAttention = TFT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
        )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        kv,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        query_length=None,
        use_cache=False,
        training=False,
    ):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            norm_x,
            mask=attention_mask,
            kv=kv,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
            query_length=query_length,
            use_cache=use_cache,
            training=training,
        )
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y, training=training)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class TFT5Block(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.is_decoder = config.is_decoder
        self.layer = []
        self.layer.append(
            TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
        )
        if self.is_decoder:
            self.layer.append(
                TFT5LayerCrossAttention(
                    config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
                )
            )

        self.layer.append(TFT5LayerFF(config, name="layer_._{}".format(len(self.layer))))

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        use_cache=False,
        training=False,
    ):
        if past_key_value_state is not None:
            assert self.is_decoder, "Only decoder can use `past_key_value_states`"
            expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
            error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
                expected_num_past_key_value_states,
                "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
                len(past_key_value_state),
            )
            assert len(past_key_value_state) == expected_num_past_key_value_states, error_message

            self_attn_past_key_value_state = past_key_value_state[:2]
            cross_attn_past_key_value_state = past_key_value_state[2:]
        else:
            self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=self_attn_past_key_value_state,
            use_cache=use_cache,
            training=training,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        if self.is_decoder and encoder_hidden_states is not None:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = shape_list(present_key_value_state[0])[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                kv=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                head_mask=head_mask,
                past_key_value_state=cross_attn_past_key_value_state,
                query_length=query_length,
                use_cache=use_cache,
                training=training,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states, training=training)
        outputs = (hidden_states,)

        # Add attentions if we output them
        outputs = outputs + (present_key_value_state,) + attention_outputs
        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


class _NoLayerEmbedTokens(object):
    """
    This class wraps the shared token embedding layer (TFSharedEmbeddings) in a plain
    Python (non-Keras-layer) class to avoid problems with weight restoring. It also
    makes sure that the layer is called from the correct scope so that the weights
    are saved and restored under the right name.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)


####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
class TFT5MainLayer(tf.keras.layers.Layer):
    def __init__(self, config, embed_tokens=None, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.config = config
        self.num_hidden_layers = config.num_layers

        self.block = [
            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.embed_tokens

    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models
    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def call(
        self,
        input_ids,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_value_states=None,
        use_cache=False,
        training=False,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if past_key_value_states is not None:
            assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_states".format(
                input_shape, (batch_size, 1)
            )
            # required mask seq length can be calculated via length of past
            # key value states and seq_length = 1 for the last token
            mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
        else:
            mask_seq_length = seq_length

        if attention_mask is None:
            attention_mask = tf.fill((batch_size, mask_seq_length), 1)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = shape_list(encoder_hidden_states)[1]
            encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)

        # initialize past_key_value_states with `None` if past does not exist
        if past_key_value_states is None:
            past_key_value_states = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves,
        # in which case we just need to make it broadcastable to all heads.
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)
        num_dims_attention_mask = len(shape_list(attention_mask))
        if num_dims_attention_mask == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif num_dims_attention_mask == 2:
            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to
            #   [batch_size, num_heads, mask_seq_length, mask_seq_length]
            if self.is_decoder:
                seq_ids = tf.range(mask_seq_length)
                causal_mask = tf.less_equal(
                    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
                )
                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                if past_key_value_states[0] is not None:
                    extended_attention_mask = extended_attention_mask[:, :, -1:, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -1e9 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        # T5 has a mask that can compare sequence ids, we could simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
        # extended_attention_mask = tf.math.equal(extended_attention_mask,
        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))

        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

        if self.is_decoder and encoder_attention_mask is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention,
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

            # T5 has a mask that can compare sequence ids, we could simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        present_key_value_states = ()
        all_hidden_states = ()
        all_attentions = ()
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds, training=training)

        for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value_state=past_key_value_state,
                use_cache=use_cache,
                training=training,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]
            if i == 0:
                # We share the position biases between the layers - the first layer stores them
                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                position_bias = layer_outputs[3 if self.output_attentions else 2]
                if self.is_decoder and encoder_hidden_states is not None:
                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]

            # append next layer key value states
            present_key_value_states = present_key_value_states + (present_key_value_state,)

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
            outputs = outputs + (present_key_value_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


####################################################
# TFT5PreTrainedModel is a sub-class of tf.keras.Model
# which takes care of loading and saving pretrained weights
# and various common utilities.
# Here you just need to specify a few (self-explanatory)
# pointers for your model.
####################################################
class TFT5PreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = T5Config
    pretrained_model_archive_map = TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        input_ids = tf.constant(DUMMY_INPUTS)
        input_mask = tf.constant(DUMMY_MASK)
        dummy_inputs = {
            "inputs": input_ids,
            "decoder_input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs


T5_START_DOCSTRING = r"""    The T5 model was proposed in
    `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
    by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
    It's an encoder-decoder transformer pre-trained in a text-to-text denoising generative setting.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`:
        https://arxiv.org/abs/1910.10683

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        inputs are usually used as a `dict` (see the T5 description above for more information) containing all the following.

        inputs (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            T5 is a model with relative position embeddings so you should be able to pad the inputs on
            the right or the left.
            Indices can be obtained using :class:`transformers.T5Tokenizer`.
            To know more on how to prepare :obj:`input_ids` for pre-training take a look at
            `T5 Training <./t5.html#training>`_ .
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for
            decoder_input_ids generation.
            If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input
            (see `decoder_past_key_value_states`).
        attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`, defaults to :obj:`None`):
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will
            also be used by default.
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up decoding.
            If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up
            decoding (see `decoder_past_key_value_states`).
        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
            `T5 Training <./t5.html#training>`_ .
        head_mask (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states " "without any specific head on top.",
    T5_START_DOCSTRING,
)
class TFT5Model(TFT5PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")

        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass

        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)

        encoder_config = copy.deepcopy(config)
        self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def get_output_embeddings(self):
        return self.shared

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape
            :obj:`(batch_size, 1, hidden_size)` is output.
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
            Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the
            sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import T5Tokenizer, TFT5Model

        tokenizer = T5Tokenizer.from_pretrained('t5-small')
        model = TFT5Model.from_pretrained('t5-small')
        input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
        outputs = model(input_ids, decoder_input_ids=input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        if isinstance(inputs, dict):
            kwargs.update(inputs)
        else:
            kwargs["inputs"] = inputs

        # retrieve arguments
        input_ids = kwargs.get("inputs", None)
        inputs_embeds = kwargs.get("inputs_embeds", None)
        attention_mask = kwargs.get("attention_mask", None)
        encoder_outputs = kwargs.get("encoder_outputs", None)
        decoder_input_ids = kwargs.get("decoder_input_ids", None)
        decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
        decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
        decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
        use_cache = kwargs.get("use_cache", True)
        head_mask = kwargs.get("head_mask", None)

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
            )

        hidden_states = encoder_outputs[0]

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_value_states=decoder_past_key_value_states,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )

        if use_cache is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]

        return decoder_outputs + encoder_outputs


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
""", T5_START_DOCSTRING) class TFT5ForConditionalGeneration(TFT5PreTrainedModel): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.model_dim = config.d_model self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared") # retrieve correct absolute scope for embed token wrapper with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: pass embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name) encoder_config = copy.deepcopy(config) self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder") decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") def get_input_embeddings(self): return self.shared def get_output_embeddings(self): return self.shared def get_encoder(self): return self.encoder def get_decoder(self): return self.decoder @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" Return: :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs. loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_label` is provided): Classification loss (cross entropy). prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input). Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`. hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention. 
    Examples::

        from transformers import T5Tokenizer, TFT5ForConditionalGeneration

        tokenizer = T5Tokenizer.from_pretrained('t5-small')
        model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
        input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
        outputs = model(input_ids, decoder_input_ids=input_ids)
        prediction_scores = outputs[0]

        tokenizer = T5Tokenizer.from_pretrained('t5-small')
        model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
        input_ids = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="tf")  # Batch size 1
        model.generate(input_ids)

        """
        if isinstance(inputs, dict):
            kwargs.update(inputs)
        else:
            kwargs["inputs"] = inputs

        # retrieve arguments
        input_ids = kwargs.get("inputs", None)
        decoder_input_ids = kwargs.get("decoder_input_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        encoder_outputs = kwargs.get("encoder_outputs", None)
        decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
        decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
        use_cache = kwargs.get("use_cache", True)
        inputs_embeds = kwargs.get("inputs_embeds", None)
        decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
        head_mask = kwargs.get("head_mask", None)

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
            )

        hidden_states = encoder_outputs[0]

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_value_states=decoder_past_key_value_states,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )

        # insert decoder past at right place
        # to speed up decoding
        if use_cache is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]

        # Rescale the decoder output before projecting it through the tied embedding matrix (LM head)
        sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
        embed_tokens = self.get_output_embeddings()
        lm_logits = embed_tokens(sequence_output, mode="linear")
        decoder_outputs = (lm_logits,) + decoder_outputs[1:]

        return decoder_outputs + encoder_outputs

    def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step
        if len(past) < 2:
            encoder_outputs, decoder_past_key_value_states = past, None
        else:
            encoder_outputs, decoder_past_key_value_states = past[0], past[1]

        return {
            "inputs": None,  # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
            "decoder_input_ids": input_ids,  # input_ids are the decoder_input_ids
            "decoder_past_key_value_states": decoder_past_key_value_states,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if len(past) < 2:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        decoder_past = past[1]
        past = (past[0],)
        reordered_decoder_past = ()
        for layer_past_states in decoder_past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)

            assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return past + (reordered_decoder_past,)
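
# A rough sketch, kept in comments so nothing runs at import time, of the cache
# layout that prepare_inputs_for_generation and _reorder_cache above handle when
# generate() is used with use_cache=True:
#
#     past = (encoder_outputs, decoder_past_key_value_states)
#     # decoder_past_key_value_states holds one tuple per decoder block, each with
#     # four tensors (self-attn key, self-attn value, cross-attn key, cross-attn value)
#     # of shape (batch_size, num_heads, seq_length, d_kv), where seq_length is the
#     # decoded length so far for the self-attention states and the encoder length
#     # for the cross-attention states.
#
# On the first generation step only encoder_outputs is available, which is what
# the `len(past) < 2` checks test for. During beam search, _reorder_cache gathers
# each cached tensor along the batch dimension with tf.gather(layer_past_state, beam_idx)
# so that every beam keeps the key/value history of the hypothesis it was expanded
# from, while encoder_outputs (past[0]) is returned unchanged.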