Trainer multi label (#7191)
* Trainer accepts multiple labels

* Missing import

* Fix docstrings
sgugger authored Sep 17, 2020
1 parent 7097459 commit 492bb6a
Showing 4 changed files with 110 additions and 29 deletions.
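In short, this commit lets Trainer handle models that take more than one label argument: TrainingArguments gains a `label_names` field, the prediction loop learns to carry tuples of logits and labels, and new `nested_*` helpers in trainer_utils.py do the bookkeeping. When `label_names` is left unset, Trainer picks a default from the model class. A rough, self-contained sketch of that default resolution (the class-name set below is an illustrative stand-in; the real code compares `type(self.model)` against `MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()`):

    # Simplified stand-in for the new default applied in Trainer.__init__.
    QA_MODEL_CLASS_NAMES = {"BertForQuestionAnswering", "DistilBertForQuestionAnswering"}  # illustrative subset

    def default_label_names(model_class_name):
        if model_class_name in QA_MODEL_CLASS_NAMES:
            return ["start_positions", "end_positions"]
        return ["labels"]

    print(default_label_names("BertForQuestionAnswering"))       # ['start_positions', 'end_positions']
    print(default_label_names("BertForSequenceClassification"))  # ['labels']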
45 changes: 30 additions & 15 deletions src/transformers/trainer.py
@@ -31,6 +31,7 @@
run_hp_search_optuna,
run_hp_search_ray,
)
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
@@ -45,6 +46,9 @@
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
nested_concat,
nested_numpify,
nested_xla_mesh_reduce,
set_seed,
)
from .training_args import TrainingArguments
@@ -293,6 +297,12 @@ def __init__(
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
self.use_tune_checkpoints = False
if self.args.label_names is None:
self.args.label_names = (
["start_positions", "end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
else ["labels"]
)

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
@@ -1307,9 +1317,9 @@ def prediction_loop(
if loss is not None:
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits))
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
@@ -1318,25 +1328,23 @@
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds)
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds))
preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = tuple(p.cpu().numpy() for p in preds)
if len(preds) == 1:
preds = preds[0]
preds = nested_numpify(preds)
if label_ids is not None:
label_ids = label_ids.cpu().numpy()
label_ids = nested_numpify(label_ids)

if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
@@ -1382,8 +1390,7 @@ def prediction_step(
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
inputs = self._prepare_inputs(inputs)

with torch.no_grad():
@@ -1402,10 +1409,18 @@
if prediction_loss_only:
return (loss, None, None)

labels = inputs.get("labels")
if labels is not None:
labels = labels.detach()
return (loss, tuple(l.detach() for l in logits), labels)
logits = tuple(logit.detach() for logit in logits)
if len(logits) == 1:
logits = logits[0]

if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
if len(labels) == 1:
labels = labels[0]
else:
labels = None

return (loss, logits, labels)

def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
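The key behavioral change in trainer.py is in prediction_step: instead of probing a hard-coded set of label keys, it now checks every key listed in `self.args.label_names`, detaches each label tensor, and unwraps the tuple when there is only one label. A standalone sketch of that logic (the helper name and batch contents are illustrative, not part of the diff):

    import torch

    def gather_labels(inputs, label_names):
        # has_labels is True only if every expected label key is present
        if not all(inputs.get(k) is not None for k in label_names):
            return None
        labels = tuple(inputs[name].detach() for name in label_names)
        return labels[0] if len(labels) == 1 else labels

    batch = {
        "input_x": torch.randn(4),
        "start_positions": torch.zeros(4, dtype=torch.long),
        "end_positions": torch.ones(4, dtype=torch.long),
    }
    print(gather_labels(batch, ["start_positions", "end_positions"]))  # tuple of two detached tensors
    print(gather_labels(batch, ["labels"]))                            # None -> no labels reported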
42 changes: 41 additions & 1 deletion src/transformers/trainer_utils.py
@@ -3,7 +3,7 @@

import numpy as np

from .file_utils import is_tf_available, is_torch_available
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .tokenization_utils_base import ExplicitEnum


@@ -132,9 +132,49 @@ class HPSearchBackend(ExplicitEnum):
}


def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
if is_torch_available():
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
else:
raise ImportError("Torch must be installed to use `nested_concat`")


def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()


def nested_detach(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()


def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm

if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")


def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
if is_torch_available():
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
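All of these helpers follow the same pattern: recurse into lists and tuples, apply the tensor operation at the leaves, and preserve the container type, so the prediction loop can treat single tensors and tuples of tensors uniformly. A quick illustration, assuming a transformers install that includes this commit and a working torch install:

    import torch
    from transformers.trainer_utils import nested_concat, nested_detach, nested_numpify

    a = (torch.ones(2, 3), torch.zeros(2, 5))
    b = (torch.ones(4, 3), torch.zeros(4, 5))

    merged = nested_concat(a, b, dim=0)      # tuple of tensors with shapes (6, 3) and (6, 5)
    arrays = nested_numpify(merged)          # same nesting, leaves converted to numpy arrays
    plain = nested_concat(torch.ones(2), torch.ones(3), dim=0)      # plain tensors still work
    detached = nested_detach((torch.ones(1, requires_grad=True),))  # detached copy, same structure

    print([t.shape for t in merged], plain.shape, arrays[0].dtype, detached[0].requires_grad)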
19 changes: 14 additions & 5 deletions src/transformers/training_args.py
@@ -2,7 +2,7 @@
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .utils import logging
@@ -128,6 +128,12 @@ class TrainingArguments:
forward method.
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
label_names (:obj:`List[str]`, `optional`):
The list of keys in your dictionary of inputs that correspond to the labels.
Will eventually default to :obj:`["labels"]`, unless the model used is one of the
:obj:`XxxForQuestionAnswering` models, in which case it will default to
:obj:`["start_positions", "end_positions"]`.
"""

output_dir: str = field(
@@ -253,13 +259,16 @@ class TrainingArguments:
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
)

def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
)
label_names: Optional[List[str]] = field(
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
)

def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

@property
def train_batch_size(self) -> int:
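Because `label_names` defaults to None and is only resolved inside Trainer.__init__, passing it explicitly is how you override the model-based default. A minimal check (the output_dir value is a throwaway placeholder):

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="tmp_trainer_output")
    print(args.label_names)  # None -> the Trainer fills this in based on the model class

    args = TrainingArguments(
        output_dir="tmp_trainer_output",
        label_names=["start_positions", "end_positions"],
    )
    print(args.label_names)  # ['start_positions', 'end_positions']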
33 changes: 25 additions & 8 deletions tests/test_trainer.py
@@ -24,17 +24,21 @@


class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42):
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)
self.label_names = ["labels"] if label_names is None else label_names
self.length = length
self.x = np.random.normal(size=(length,)).astype(np.float32)
self.y = a * self.x + b + np.random.normal(scale=0.1, size=(length,))
self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
self.ys = [y.astype(np.float32) for y in self.ys]

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_x": self.x[i], "label": self.y[i]}
result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
result["input_x"] = self.x[i]
return result


class AlmostAccuracy:
@@ -68,16 +72,17 @@ def __init__(self, a=0, b=0, double_output=False):
self.double_output = double_output
self.config = None

def forward(self, input_x=None, labels=None):
def forward(self, input_x=None, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
train_dataset = RegressionDataset(length=train_len)
eval_dataset = RegressionDataset(length=eval_len)
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
model = RegressionModel(a, b, double_output)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
@@ -174,7 +179,7 @@ def test_evaluate(self):
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
@@ -185,7 +190,7 @@ def test_evaluate(self):
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()

x, y = trainer.eval_dataset.x, trainer.eval_dataset.y
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
@@ -212,6 +217,18 @@ def test_predict(self):
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))

# With more than one output/label of the model
trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"])
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
labels = outputs.label_ids
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

def test_trainer_with_datasets(self):
np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
