"""PyTorch OpenAI GPT-2 model."""


import logging
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://cdn.huggingface.co/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-pytorch_model.bin",
}


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state = n_embd
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # convert to set and remove already pruned heads
        for head in heads:
            # compute how many pruned heads are before this head and shift the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune the conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper-parameters and store pruned heads
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back, cf. below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have the same shapes for stacking
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state = 4 * n_embd
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
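
    As a rough illustration (not an exhaustive reference), the two ways of building a model differ
    only in where the weights come from; the checkpoint name ``'gpt2'`` below is just an example
    identifier::

        from transformers import GPT2Config, GPT2Model

        config = GPT2Config()
        model = GPT2Model(config)                  # randomly initialized weights
        model = GPT2Model.from_pretrained('gpt2')  # downloads and loads the pretrained weights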
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            If `past` is used, optionally only the last `input_ids` have to be input (see `past`).

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            If `past` is used, the user can optionally input only the last `input_ids` (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)` instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`. A short usage sketch is given at the end of this argument list.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            `input_ids_length` = `sequence_length` if `past` is None else `1`
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
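
    As a hedged sketch of how `past` and `use_cache` work together (the prompt text and the greedy
    token choice below are only illustrative, not part of the documented API)::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor([tokenizer.encode("Hello, my dog")])
        logits, past = model(input_ids)[:2]              # use_cache=True by default, so `past` is returned
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        logits, past = model(next_token, past=past)[:2]  # feed only the new token together with the cached past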
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
    ):
        r"""
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import GPT2Tokenizer, GPT2Model
        import torch

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple


        """
        # If `past` is used, only the last token of each input has to be processed, so the
        # full-length inputs are trimmed down to their final position.
        if past is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask: build a broadcastable [batch_size, 1, 1, to_seq_length] mask and turn
        # masked positions into large negative values that vanish after the softmax.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed: 1.0 in head_mask indicates we keep the head.
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add the last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads be free (-1) so we can extract attentions even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only keep the last token of input_ids if `past` is defined
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]
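
        # A hedged sketch (not part of the original example): the loss above is equivalent to
        # manually shifting the logits and labels for next-token prediction.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        manual_loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )  # equal to `loss` up to floating point error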

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the input of a specified classification token index in the input sequence).
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        use_cache=True,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1[``.
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`multiple_choice_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]
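
        # A hedged continuation: the same call with labels attached; picking index 0 for mc_labels
        # (i.e. treating the first choice as correct) is arbitrary and only for illustration.
        mc_labels = torch.tensor([0])
        outputs = model(input_ids, mc_token_ids=mc_token_ids, lm_labels=input_ids, mc_labels=mc_labels)
        lm_loss, mc_loss = outputs[:2]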

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)