
Compute Loss inside the training step. #686

Merged
101 changes: 81 additions & 20 deletions optimum/onnxruntime/trainer.py
@@ -20,6 +20,7 @@
import shutil
import sys
import time
import types
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -134,6 +135,32 @@
SCALER_NAME = "scaler.pt"


class ModuleWithLoss(nn.Module):
Reviewer comment (Collaborator): Since ModuleWithLoss is a wrapper for a torch.nn.Module subclass, can you add a `module` property so that `unwrap_model` stays compatible? I believe ORTModule does the same.

def __init__(self, model, args, label_smoother):
super().__init__()
self._original_model = model
self.args = args
# Label smoothing
self.label_smoother = label_smoother

def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
# compute_model_plus_loss_internal is assigned when the class is instantiated.
# It should have the same signature as Trainer.compute_loss().
# We do this to avoid potentially out-of-sync state that duplicating the compute-loss code could introduce.
return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)

@property
def module(self):
"""The original `torch.nn.Module` that this module wraps.
This property provides access to methods and properties on the original module."""

return self._original_model.module

@property
def config(self):
return self._original_model.config


class ORTFeaturesManager:
_TASKS_TO_ORTMODELS = {
"default": ORTModelForFeatureExtraction,
Expand Down Expand Up @@ -279,12 +306,55 @@ def __init__(
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# We use both a training model and an inference model in conjunction with `model`.
# `_training_model` is wrapped so it runs with ORT and uses the overridden functions in ModuleWithLoss.
# `self.model` keeps the default version of the model and is restored (unwrapped) for eval/test.

# Only wrap the model if the --use_module_with_loss flag is passed.
if args.use_module_with_loss:
self._training_model = self.create_model_with_loss()

self.model = model

self.feature = feature
self.onnx_model_path = onnx_model_path
self.exported_with_loss = False
if self.args.local_rank:
torch.cuda.set_device(self.args.local_rank)

# This method creates a ModuleWithLoss instance to use when the --use_module_with_loss flag is passed.
# It helps reduce peak memory usage by computing the loss inside the training step.
def create_model_with_loss(self):
model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother)
model_with_loss.compute_model_plus_loss_internal = types.MethodType(Trainer.compute_loss, model_with_loss)

return model_with_loss

# We assume that the training model and the inference model have the same forward signature columns.
# The self._signature_columns attribute only stores the signature parsed the first time.
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
import inspect

if isinstance(self.model, ModuleWithLoss):
signature = inspect.signature(self.model._original_model.forward)
else:
signature = inspect.signature(self.model.forward)

self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids; the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def compute_loss(self, model_with_loss, inputs, return_outputs=False):
# Run model forward + loss compute.
if isinstance(self.model, ModuleWithLoss):
# ORTModule does not support the BatchEncoding type, so we have to convert it to a dict.
dict_inputs = dict(inputs.items())
return model_with_loss(dict_inputs, return_outputs)
else:
return super().compute_loss(model_with_loss, inputs, return_outputs)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -313,6 +383,8 @@ def train(
"You need to install `onnxruntime-training` to use `ORTTrainer` for training. Check out "
"https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime."
)
if self.args.use_module_with_loss:
self.model = self._training_model

if resume_from_checkpoint is False:
resume_from_checkpoint = None
@@ -801,6 +873,8 @@ def evaluate(
dictionary also contains the epoch number which comes from the training state.
"""
# memory metrics - must set up as early as possible
# TODO: We need to enable evaluation using ORT backend.
Reviewer comment: This TODO is only meant for the --loss_in_train flag, right?

self.model = unwrap_model(self.model)
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
@@ -892,6 +966,9 @@ def predict(
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
# TODO: We need to enable evaluation using ORT backend.
self.model = unwrap_model(self.model)

# memory metrics - must set up as early as possible
self._memory_tracker.start()

@@ -909,10 +986,7 @@

try:
output = eval_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
except Exception as error:
logger.error(error)
@@ -1522,13 +1596,7 @@ def _export(
opset = max(opset, 12) # Operators like `nll_loss`are added for opset>=12

output_path = model_path / ONNX_WEIGHTS_NAME
_ = export(model=model, config=onnx_config, opset=opset, output=output_path, device=device)

model.config.save_pretrained(model_path)

@@ -1697,11 +1765,7 @@ def create_optimizer(self):
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(params=optimizer_grouped_parameters, optim=optimizer_cls, **optimizer_kwargs)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
@@ -1731,10 +1795,7 @@ def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, A
The training arguments for the training session.
"""
optimizer_kwargs = {"lr": args.learning_rate}
adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}
if args.optim == ORTOptimizerNames.ADAMW_ORT_FUSED:
try:
from onnxruntime.training.optim import FusedAdam
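As background for the create_model_with_loss helper above, here is a minimal, self-contained sketch (not the PR's code; the class and function names are made up for illustration) of what types.MethodType does: it binds a free function to one specific instance, which is how Trainer.compute_loss becomes compute_model_plus_loss_internal on the ModuleWithLoss wrapper without duplicating the loss logic.

import types

class Host:
    def __init__(self, value):
        self.value = value

def compute(self, x):
    # Runs with the Host instance as `self`, analogous to Trainer.compute_loss
    # running with the ModuleWithLoss instance it is bound to.
    return self.value + x

host = Host(10)
host.compute = types.MethodType(compute, host)  # bound to this instance only
print(host.compute(5))  # prints 15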
17 changes: 14 additions & 3 deletions optimum/onnxruntime/training_args.py
@@ -60,9 +60,13 @@ class ORTTrainingArguments(TrainingArguments):
The optimizer to use, including optimizers in Transformers: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor, as well as optimizers implemented by ONNX Runtime: adamw_ort_fused.
"""

optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer to use."})

use_module_with_loss: Optional[bool] = field(
default=False,
metadata={
"help": "Use ModuleWithLoss Wrapper to compute loss inside the training loop, having this will help save memory for ORTModule Runs."
},
)

# This method will not need to be overridden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
@@ -336,3 +340,10 @@ def __post_init__(self):
f"{self.hub_model_id}).",
FutureWarning,
)
if self.use_module_with_loss:
logger.info(
"Using the ModuleWithLoss wrapper. "
"The loss will be computed inside the training step, which reduces peak memory usage."
)
else:
logger.info("Not using the ModuleWithLoss wrapper.")