Skip to content

Commit

Permalink
Trainer automatically drops unused columns in nlp datasets (huggingfa…
Browse files Browse the repository at this point in the history
…ce#6449)

* Add a classmethod to easily build a Trainer from nlp dataset and metric

* Fix docstrings

* Split train/eval

* Formatting

* Log dropped columns + docs

* Authorize callable activations

* Poc for auto activation

* Be framework-agnostic

* Formatting

* Remove class method

* Remove unnecessary code
  • Loading branch information
sgugger authored and Zigur committed Oct 26, 2020
1 parent 52e8739 commit 06426dc
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
add_start_docstrings,
cached_path,
is_apex_available,
is_nlp_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
_tf_available = False # pylint: disable=invalid-name


try:
import nlp # noqa: F401

_nlp_available = True

except ImportError:
_nlp_available = False

try:
from torch.hub import _get_torch_home

Expand Down Expand Up @@ -144,6 +152,10 @@ def is_torch_tpu_available():
return _torch_tpu_available


def is_nlp_available():
return _nlp_available


def is_psutil_available():
return _psutil_available

Expand Down
55 changes: 43 additions & 12 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
import math
import os
Expand All @@ -19,7 +20,7 @@
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, default_data_collator
from .file_utils import is_torch_tpu_available
from .file_utils import is_nlp_available, is_torch_tpu_available
from .integrations import is_comet_available, is_tensorboard_available, is_wandb_available
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
Expand All @@ -41,6 +42,8 @@
_use_native_amp = True
from torch.cuda.amp import autocast

if is_nlp_available():
import nlp

if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -140,19 +143,19 @@ class Trainer:
model (:class:`~transformers.PreTrainedModel`):
The model to train, evaluate or use for predictions.
args (:class:`~transformers.TrainingArguments`):
The arguments to tweak training.
The arguments to tweak for training.
data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
The function to use to from a batch from a list of elements of :obj:`train_dataset` or
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
:obj:`eval_dataset`.
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for training.
The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for evaluation.
The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
When performing evaluation and predictions, only returns the loss.
tb_writer (:obj:`SummaryWriter`, `optional`):
Object to write to TensorBoard.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
Expand Down Expand Up @@ -228,11 +231,32 @@ def __init__(
),
FutureWarning,
)

if is_nlp_available():
if isinstance(train_dataset, nlp.Dataset):
self._remove_unused_columns(self.train_dataset, description="training")
if isinstance(eval_dataset, nlp.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")

self.global_step = None
self.epoch = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()

def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
signature_columns += ["label", "label_ids"]
columns = [k for k in signature_columns if k in dataset.column_names]
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
dset_description = "" if description is None else f"in the {description} set "
logger.info(
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)
dataset.set_format(columns=columns)

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
Expand Down Expand Up @@ -287,11 +311,13 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`.
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")

elif eval_dataset is not None and is_nlp_available() and isinstance(eval_dataset, nlp.Dataset):
self._remove_unused_columns(eval_dataset, description="evaluation")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
eval_sampler = self._get_eval_sampler(eval_dataset)

Expand All @@ -314,8 +340,11 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The test dataset to use.
The test dataset to use. If it is an :obj:`nlp.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
"""
if is_nlp_available() and isinstance(test_dataset, nlp.Dataset):
self._remove_unused_columns(test_dataset, description="test")
test_sampler = self._get_eval_sampler(test_dataset)

# We use the same batch_size as for eval.
Expand Down Expand Up @@ -903,7 +932,8 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
Expand All @@ -929,7 +959,8 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on.
Dataset to run the predictions on. If it is an :obj:`nlp.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
Returns:
`NamedTuple`:
Expand Down

0 comments on commit 06426dc

Please sign in to comment.