from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from ..tokenization_utils import PreTrainedTokenizer


class DataCollator(ABC):
    """
    A `DataCollator` is responsible for batching
    and pre-processing samples of data as requested by the training loop.
    """

    @abstractmethod
    def collate_batch(self) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.

        Returns:
            A dictionary of tensors
        """
        pass


InputDataClass = NewType("InputDataClass", Any)


@dataclass
class DefaultDataCollator(DataCollator):
    """
    Very simple data collator that:
    - simply collates batches of dict-like objects
    - performs special handling for potential keys named:
        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object
    - does not do any additional preprocessing

    i.e., property names of the input object will be used as corresponding inputs to the model.
    See GLUE and NER for examples of how it's useful.
    """

    def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
        # In this method we'll make the assumption that all `features` in the batch
        # have the same attributes.
        # So we will look at the first element as a proxy for what attributes exist
        # on the whole batch.
        first = features[0]

        # Special handling for labels.
        # Ensure that tensor is created with the correct type
        # (it should be automatically the case, but let's make sure of it.)
        if hasattr(first, "label") and first.label is not None:
            if type(first.label) is int:
                labels = torch.tensor([f.label for f in features], dtype=torch.long)
            else:
                labels = torch.tensor([f.label for f in features], dtype=torch.float)
            batch = {"labels": labels}
        elif hasattr(first, "label_ids") and first.label_ids is not None:
            if type(first.label_ids[0]) is int:
                labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
            else:
                labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
            batch = {"labels": labels}
        else:
            batch = {}

        # Handling of all other possible attributes.
        # Again, we will use the first element to figure out which keys/values are not None for this model.
        for k, v in vars(first).items():
            if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
                batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
        return batch
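
# Illustrative sketch (not part of the library API): how `DefaultDataCollator` turns a list of
# feature objects into a batch of tensors. The `_ExampleFeatures` dataclass and the
# `_example_default_collation` helper below are hypothetical names used only for illustration;
# real feature classes are produced by the glue/ner processors.


@dataclass
class _ExampleFeatures:
    # A minimal stand-in for a processor-produced feature object.
    input_ids: List[int]
    attention_mask: List[int]
    label: int


def _example_default_collation() -> Dict[str, torch.Tensor]:
    # Build two fake samples of identical (already padded) length, then collate them.
    features = [
        _ExampleFeatures(input_ids=[101, 7592, 102, 0], attention_mask=[1, 1, 1, 0], label=0),
        _ExampleFeatures(input_ids=[101, 2088, 999, 102], attention_mask=[1, 1, 1, 1], label=1),
    ]
    batch = DefaultDataCollator().collate_batch(features)
    # batch["labels"] has shape (2,) and dtype torch.long because `label` is an int;
    # batch["input_ids"] and batch["attention_mask"] have shape (2, 4).
    return batch
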
@dataclass
class DataCollatorForLanguageModeling(DataCollator):
    """
    Data collator used for language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
    """

    tokenizer: PreTrainedTokenizer
    mlm: bool = True
    mlm_probability: float = 0.15

    def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch = self._tensorize_batch(examples)
        if self.mlm:
            inputs, labels = self.mask_tokens(batch)
            return {"input_ids": inputs, "masked_lm_labels": labels}
        else:
            return {"input_ids": batch, "labels": batch}

    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have a pad token."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling."
                " Remove the --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training
        # (with probability `self.mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa).
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
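
# Illustrative sketch (not part of the library API): wiring `DataCollatorForLanguageModeling`
# to a tokenizer for masked language modeling. `_example_mlm_collation` is a hypothetical
# helper for illustration only; it assumes the `bert-base-uncased` vocabulary can be loaded,
# e.g. from a local cache or by download.


def _example_mlm_collation() -> Dict[str, torch.Tensor]:
    # Local import to keep the sketch self-contained; the relative path assumes the BERT
    # tokenizer module lives at the package root, as in the current source layout.
    from ..tokenization_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    examples = [
        torch.tensor(tokenizer.encode("Hello world"), dtype=torch.long),
        torch.tensor(tokenizer.encode("A slightly longer example sentence"), dtype=torch.long),
    ]
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    batch = collator.collate_batch(examples)
    # The shorter example is padded with tokenizer.pad_token_id, roughly 15% of the non-special
    # tokens are masked in batch["input_ids"], and batch["masked_lm_labels"] is -100 everywhere
    # except at the masked positions, where it holds the original token ids.
    return batch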