Update ORTTrainer with transformers 4.22.1 release #388

Merged · 6 commits · Sep 22, 2022
Changes from 5 commits
13 changes: 6 additions & 7 deletions README.md
@@ -195,7 +195,9 @@ As you can see, the process is similar in each case:

### Training

Besides supporting ONNX Runtime inference, 🤗 Optimum also supports ONNX Runtime training, reducing the memory and computations needed during training. This can be achieved by using the class `ORTTrainer`, which possess a similar behavior than the `Trainer` of 🤗 Transformers:
Besides supporting ONNX Runtime inference, 🤗 Optimum also supports training with the ONNX Runtime backend. The `ORTTrainer` class behaves like the `Trainer` of 🤗 Transformers, but reduces memory consumption and optimizes the computation graph during training. As a result, training is accelerated and you can feed larger batch sizes to your device.

Replace `Trainer` with `ORTTrainer` to leverage ONNX Runtime on fine-tuning tasks:

```diff
-from transformers import Trainer
@@ -211,17 +213,14 @@ Besides supporting ONNX Runtime inference, 🤗 Optimum also supports ONNX Runti
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
feature="sequence-classification",
+ feature="sequence-classification",
)

# Step 2: Use ONNX Runtime for training and evalution!🤗
# Step 2: Use ONNX Runtime for training!🤗
train_result = trainer.train()
eval_metrics = trainer.evaluate()
```

By replacing `Trainer` by `ORTTrainer`, you will be able to leverage ONNX Runtime for fine-tuning tasks.

Check out the [`examples`](https://github.com/huggingface/optimum/tree/main/examples) directory for more sophisticated usage.
Check out the [`examples`](https://github.com/huggingface/optimum/tree/main/examples) for more sophisticated usage.

Happy optimizing 🤗!

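For reference, here is a minimal end-to-end sketch of the workflow the updated README describes. The checkpoint name and the tiny in-memory dataset are illustrative placeholders, not part of this PR, and running it requires an ONNX Runtime training environment (`torch_ort`):

```python
# Minimal sketch: swapping Trainer for ORTTrainer (placeholder model/dataset).
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
from optimum.onnxruntime import ORTTrainer

model_id = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

# Tiny illustrative dataset; any tokenized text-classification dataset works the same way.
raw = Dataset.from_dict({"text": ["great movie", "terrible movie"], "label": [1, 0]})
dataset = raw.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=16))

trainer = ORTTrainer(
    model=model,
    args=TrainingArguments(output_dir="ort-output", num_train_epochs=1),
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=tokenizer,
    feature="sequence-classification",  # ONNX export feature, as in the README snippet
)

train_result = trainer.train()   # fine-tuning runs on the ONNX Runtime backend
eval_metrics = trainer.evaluate()
```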
61 changes: 54 additions & 7 deletions optimum/onnxruntime/trainer.py
@@ -14,7 +14,7 @@
"""
The ORTTrainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task with ONNX Runtime.
"""

import functools
import math
import os
import sys
@@ -23,7 +23,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union

from packaging.version import parse
from tqdm.auto import tqdm


@@ -57,6 +56,7 @@
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager
from transformers.pytorch_utils import is_torch_less_than_1_11
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState
@@ -65,6 +65,7 @@
IterableDatasetShard,
SequentialDistributedSampler,
find_batch_size,
get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
@@ -74,6 +75,7 @@
from transformers.trainer_utils import (
EvalLoopOutput,
EvalPrediction,
FSDPOption,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
@@ -237,6 +239,7 @@ def train(
raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
# This might change the seed so needs to run first.
self._hp_search_setup(trial)
self._train_batch_size = self.args.train_batch_size

# Model re-init
model_reloaded = False
@@ -469,7 +472,7 @@ def _inner_training_loop(
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
train_dataloader.sampler, RandomSampler
)
if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
# That was before PyTorch 1.11 however...
for _ in train_dataloader:
@@ -958,13 +961,15 @@ def evaluation_loop_ort(
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples
if num_samples == 0 and observed_num_examples > 0:
num_samples = observed_num_examples

# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samplers has been rounded to a multiple of batch_size, so we truncate.
@@ -1353,9 +1358,51 @@ def _wrap_model(self, model, training=True, dataloader=None):
)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
raise NotImplementedError(
"PyTorch's distrubuted data parallel features are not supported by `ORTTrainer` yet."
)
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

if FSDPOption.OFFLOAD in self.args.fsdp:
raise NotImplementedError("CPU offload is not compatible with `torch_ort.ORTModule`.")
else:
cpu_offload = CPUOffload(offload_params=False)

auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_min_num_params > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
)
elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = get_module_class_from_name(
model, self.args.fsdp_transformer_layer_cls_to_wrap
)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls={transformer_cls_to_wrap},
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
)
if FSDPOption.OFFLOAD not in self.args.fsdp:
model.to(self.args.device)
elif is_sagemaker_dp_enabled():
raise NotImplementedError(
"Sagemaker's distrubuted data parallel features are not supported by `ORTTrainer` yet."
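The FSDP branch added to `_wrap_model` above is driven entirely by fields on the training arguments (`args.fsdp`, `args.fsdp_min_num_params`, `args.fsdp_transformer_layer_cls_to_wrap`, `args.fp16`/`args.bf16`). As a hedged sketch, and assuming these flags are passed the same way as with the vanilla `Trainer`, a configuration exercising the new code path could look like the following; it would need to be launched under a distributed launcher such as `torchrun`:

```python
# Sketch of training arguments that would reach the new FSDP branch in _wrap_model.
# Field names mirror those read in the diff; the accepted string values are assumptions.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="ort-fsdp-output",
    fsdp="full_shard auto_wrap",                     # sharded training + FSDPOption.AUTO_WRAP
    fsdp_transformer_layer_cls_to_wrap="BertLayer",  # handled by transformer_auto_wrap_policy
    bf16=True,                                       # mapped to MixedPrecision(..., torch.bfloat16)
)
# Note: requesting CPU offload ("offload" in the fsdp string) raises NotImplementedError here,
# since it is not compatible with torch_ort.ORTModule.
```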
15 changes: 6 additions & 9 deletions optimum/onnxruntime/trainer_seq2seq.py
@@ -107,9 +107,8 @@ def evaluate(
"""

gen_kwargs = gen_kwargs.copy()
gen_kwargs["max_length"] = (
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
@@ -164,9 +163,8 @@ def predict(
"""

gen_kwargs = gen_kwargs.copy()
gen_kwargs["max_length"] = (
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
@@ -581,9 +579,8 @@ def prediction_step_ort(

# XXX: adapt synced_gpus for fairscale as well
gen_kwargs = self._gen_kwargs.copy()
gen_kwargs["max_length"] = (
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.model.config.max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
)
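The three `gen_kwargs` changes in this file share one intent: only fall back to a default `max_length` when the caller supplied neither `max_length` nor `max_new_tokens`. A small illustration, assuming `trainer` is an already-constructed `ORTSeq2SeqTrainer`:

```python
# Sketch of the new fallback behaviour for generation kwargs ("trainer" assumed to exist).

# Caller passes max_new_tokens: max_length is no longer forced to
# args.generation_max_length, so the two settings cannot conflict.
metrics = trainer.evaluate(max_new_tokens=64, num_beams=4)

# Caller passes neither: the previous default still applies.
metrics = trainer.evaluate()  # falls back to args.generation_max_length
```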