Clean up data collators and datasets #8308

Merged: 3 commits, Nov 4, 2020
14 changes: 12 additions & 2 deletions examples/language-modeling/run_mlm.py
@@ -264,7 +264,15 @@ def main():
def tokenize_function(examples):
# Remove empty lines
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
return tokenizer(
examples["text"],
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
# receives the `special_tokens_mask`.
return_special_tokens_mask=True,
)

tokenized_datasets = datasets.map(
tokenize_function,
@@ -275,8 +283,10 @@ def tokenize_function(examples):
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

tokenized_datasets = datasets.map(
tokenize_function,
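A minimal sketch, not part of the diff, of what the run_mlm.py change buys: with `return_special_tokens_mask=True` the special-tokens mask is computed once at tokenization time, so DataCollatorForLanguageModeling can reuse it instead of calling `tokenizer.get_special_tokens_mask` on every batch. The checkpoint name is an arbitrary assumption.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any MLM checkpoint works here

encodings = tokenizer(
    ["The first line.", "A slightly longer second line."],
    truncation=True,
    max_length=128,
    return_special_tokens_mask=True,  # precomputed once, reused by the collator below
)
features = [{k: encodings[k][i] for k in encodings} for i in range(2)]

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
batch = collator(features)  # pads dynamically and masks tokens; special tokens are never selected
print(batch["input_ids"].shape, batch["labels"].shape)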
1 change: 0 additions & 1 deletion src/transformers/__init__.py
@@ -281,7 +281,6 @@
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorForTokenClassification,
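The line removed above drops the NSP-specific collator from the top-level exports (the class itself is deleted further down). A minimal sketch, not from the diff, of the import-level effect:

# After this change the following import raises ImportError:
# from transformers import DataCollatorForNextSentencePrediction
# The generic collator remains exported and handles dynamic padding and MLM masking:
from transformers import DataCollatorForLanguageModeling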
266 changes: 90 additions & 176 deletions src/transformers/data/data_collator.py
@@ -1,4 +1,5 @@
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

@@ -175,72 +176,111 @@ def __call__(self, features):
return batch


def _collate_batch(examples, tokenizer):
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
# Tensorize if necessary.
if isinstance(examples[0], (list, tuple)):
examples = [torch.tensor(e, dtype=torch.long) for e in examples]

# Check if padding is necessary.
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)

# If yes, check if we have a `pad_token`.
if tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({tokenizer.__class__.__name__}) does not have a pad token."
)

# Creating the full tensor and filling it with our data.
max_length = max(x.size(0) for x in examples)
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
for i, example in enumerate(examples):
if tokenizer.padding_side == "right":
result[i, : example.shape[0]] = example
else:
result[i, -example.shape[0] :] = example
return result
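A small usage sketch, not part of the diff, for the new `_collate_batch` helper; it assumes a tokenizer with a pad token, and the checkpoint name and token ids are illustrative only. Unequal-length examples are padded on the tokenizer's padding side, right by default.

from transformers import AutoTokenizer
from transformers.data.data_collator import _collate_batch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # pad_token_id == 0
examples = [[101, 7592, 102], [101, 7592, 2088, 999, 102]]      # ragged lists of token ids
batch = _collate_batch(examples, tokenizer)                     # lists are tensorized, then padded
assert batch.shape == (2, 5)
assert batch[0, 3:].eq(tokenizer.pad_token_id).all()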


@dataclass
class DataCollatorForLanguageModeling:
"""
Data collator used for language modeling.
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
are not all of the same length.

- collates batches of tensors, honoring their tokenizer's pad_token
- preprocesses batches for masked language modeling
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
non-masked tokens and the value to predict for the masked token.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.

.. note::

For best performance, this data collator should be used with a dataset having items that are dictionaries or
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
argument :obj:`return_special_tokens_mask=True`.
"""

tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15

def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)

def __call__(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
batch = self.tokenizer.pad(examples, return_tensors="pt")
else:
batch = {"input_ids": _collate_batch(examples, self.tokenizer)}

# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "labels": labels}
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = batch.clone().detach()
labels = batch["input_ids"]
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}

def _tensorize_batch(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
# In order to accept both lists of lists and lists of Tensors
if isinstance(examples[0], (list, tuple)):
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
else:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
batch["labels"] = labels
return batch
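A short sketch, not part of the diff, of the non-MLM branch of the rewritten `__call__`: plain lists of token ids go through `_collate_batch`, and the labels mirror the padded input ids with pad positions set to -100. The checkpoint name is an arbitrary assumption.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer with a pad token
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

examples = [tok("a short line")["input_ids"], tok("a somewhat longer line of text")["input_ids"]]
batch = collator(examples)
print(batch["input_ids"].shape)  # (2, length of the longest example)
print(batch["labels"][0])        # the input ids, with pad positions replaced by -100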

def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def mask_tokens(
self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""

if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)

labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()

probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
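The hunk is truncated before the replacement step, but the 80% MASK / 10% random / 10% original strategy named in the docstring is unchanged by this PR (the identical logic is visible further down in the removed NSP collator). A self-contained sketch of that step for reference; the function name here is made up:

import torch

def apply_mlm_replacements(inputs: torch.Tensor, masked_indices: torch.Tensor,
                           mask_token_id: int, vocab_size: int) -> torch.Tensor:
    """Of the selected positions: 80% become [MASK], 10% a random id, 10% stay unchanged."""
    inputs = inputs.clone()
    # 80% of the selected positions are replaced with the mask token.
    indices_replaced = torch.bernoulli(torch.full(inputs.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id
    # Half of the remainder (10% overall) are replaced with a random token.
    indices_random = (
        torch.bernoulli(torch.full(inputs.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, inputs.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    # The remaining 10% keep their original ids.
    return inputs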

@@ -385,9 +425,16 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling):
- preprocesses batches for both masked language modeling and sentence order prediction
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
"DataCollatorForLanguageModeling instead.",
FutureWarning,
)

def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
input_ids = [example["input_ids"] for example in examples]
input_ids = self._tensorize_batch(input_ids)
input_ids = _collate_batch(input_ids, self.tokenizer)
input_ids, labels, attention_mask = self.mask_tokens(input_ids)

token_type_ids = [example["token_type_ids"] for example in examples]
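A quick sketch, not part of the diff, of the effect of the new `__init__` above: constructing the deprecated collator now emits a FutureWarning pointing users to DataCollatorForLanguageModeling. The checkpoint name is an assumption.

import warnings
from transformers import AutoTokenizer, DataCollatorForSOP

tok = AutoTokenizer.from_pretrained("albert-base-v2")  # assumed ALBERT checkpoint
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DataCollatorForSOP(tokenizer=tok)  # deprecated in favour of DataCollatorForLanguageModeling
assert any(issubclass(w.category, FutureWarning) for w in caught)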
@@ -582,136 +629,3 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
) & masked_indices[i]

return inputs.long(), perm_mask, target_mapping, labels.long()


@dataclass
class DataCollatorForNextSentencePrediction:
"""
Data collator used for next sentence prediction.
- collates examples which contain pre-generated negative examples
- preprocesses batches for masked language modeling
"""

tokenizer: PreTrainedTokenizerBase
mlm: bool = True
block_size: int = 512
short_seq_probability: float = 0.1
nsp_probability: float = 0.5
mlm_probability: float = 0.15

def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
The input should contain negative examples; :class:`~transformers.DataCollatorForNextSentencePrediction` will
not generate any negative examples.

Args:
examples (:obj:`List[Dict]`): Each dictionary should have the following keys:

- ``tokens_a``: A sequence of tokens, which should appear before ``tokens_b`` in the text.
- ``tokens_b``: A sequence of tokens, which should appear after ``tokens_a`` in the text.
- ``is_random_next``: 1 if this pair is generated randomly, else 0.
"""

tokens_a = [e["tokens_a"] for e in examples]
tokens_b = [e["tokens_b"] for e in examples]
nsp_labels = [1 if e["is_random_next"] else 0 for e in examples]

input_ids = []
segment_ids = []
attention_masks = []

assert len(tokens_a) == len(tokens_b)
for i in range(len(tokens_a)):
input_id, attention_mask, segment_id = self.create_features_from_example(tokens_a[i], tokens_b[i])
input_ids.append(input_id)
segment_ids.append(segment_id)
attention_masks.append(attention_mask)
if self.mlm:
input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
else:
input_ids = self._tensorize_batch(input_ids)

result = {
"input_ids": input_ids,
"attention_mask": self._tensorize_batch(attention_masks),
"token_type_ids": self._tensorize_batch(segment_ids),
"labels": mlm_labels if self.mlm else None,
"next_sentence_label": torch.tensor(nsp_labels),
}
return result

def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
else:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

def create_features_from_example(self, tokens_a, tokens_b):
"""Creates examples for a single document."""

max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)

tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
tokens_a,
tokens_b,
num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens,
truncation_strategy="longest_first",
)

input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
attention_mask = [1] * len(input_id)
segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
assert len(input_id) <= self.block_size

# pad
while len(input_id) < self.block_size:
input_id.append(0)
attention_mask.append(0)
segment_id.append(0)

input_id = torch.tensor(input_id)
attention_mask = torch.tensor(attention_mask)
segment_id = torch.tensor(segment_id)

return input_id, attention_mask, segment_id

def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""

if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)

labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
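For reference, a rough equivalent, not from the diff, of what the removed `create_features_from_example` did, expressed as a standard tokenizer call on a text pair (truncate the pair to the block size, add special tokens, build token type ids, pad to a fixed length). The checkpoint, texts, and block size are illustrative assumptions.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
block_size = 512

encoded = tokenizer(
    "first segment of text",
    "second segment of text",
    truncation="longest_first",
    padding="max_length",
    max_length=block_size,
)
# encoded["input_ids"], encoded["attention_mask"] and encoded["token_type_ids"] play the roles of
# the input_id / attention_mask / segment_id triple built manually in the removed helper.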