Compute Loss inside the training step. #686

Merged
89 changes: 75 additions & 14 deletions optimum/onnxruntime/trainer.py
@@ -23,6 +23,12 @@
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
from transformers.integrations import hp_params, is_fairscale_available # isort: split

useless references?

Contributor Author

This was coming from main when I resolved the conflict.


But ideally, after the merge, these two lines should not appear in the diff.

Collaborator

Will remove this to pass code quality check.

Collaborator

Ok, I don't have push access. Can you use ruff styling with the command make style and remove the redundant dependencies? Thx!

Contributor Author
@AdamLouly, Mar 2, 2023

> Ok, I don't have push access. Can you use ruff styling with the command make style and remove the redundant dependencies? Thx!

I did make style and it formatted trainer.py, but the CI still says it should be formatted. Is there a specific configuration for this?

It seems the CI uses the latest black version every time, so we should always upgrade black before formatting; otherwise it will reformat other files that were previously formatted with a different version.

Collaborator

Hi @AdamLouly, yes, the CI always uses the latest formatting tools, and whenever the team observes a failure of the code quality check CI, we fix it.

Your previous formatting issue could come from the fact that we recently switched from isort to ruff (#760). If you want to be more cautious, you can update your formatters with pip install -U .[quality] before running make style.




# Integrations must be imported before ML frameworks:
@@ -45,6 +51,7 @@
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.file_utils import (
is_apex_available,
is_sagemaker_dp_enabled,
@@ -134,6 +141,28 @@
SCALER_NAME = "scaler.pt"


class ModuleWithLoss(nn.Module):
Collaborator

ModuleWithLoss is a wrapper around a torch.nn.Module subclass; can you add a module property so that unwrap_model stays compatible? I believe ORTModule did the same.
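A minimal sketch of the property being asked for here, assuming the wrapper keeps the original model in self._original_model as in this PR (the property name module mirrors what ORTModule exposes; it is illustrative, not code from this PR):

    @property
    def module(self):
        # transformers' unwrap_model() follows `.module` attributes recursively,
        # so exposing the wrapped model here keeps the wrapper compatible with it.
        return self._original_model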

def __init__(self, model, args) -> None:
super().__init__()
self._original_model = model
self.args = args
self.hf_trainer = Trainer(model)
# Label smoothing
if self.args.label_smoothing_factor != 0:
from transformers.trainer_pt_utils import LabelSmoother

self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None
Collaborator

Why is there a case for not using the label smoother? In transformers, unless the label smoother is used, the loss is already computed in the forward pass. Cf. gpt2.
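For reference, a quick way to see this behavior outside of the PR (a standalone sketch that assumes the transformers library and the public gpt2 checkpoint):

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

enc = tokenizer("hello world", return_tensors="pt")
outputs = model(**enc, labels=enc["input_ids"])

print(outputs.loss)          # the loss is already computed inside forward()
print(outputs.logits.shape)  # but the logits (and past_key_values) are returned as well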

Collaborator

And with the wrapper, a model that doesn't use the label smoother in the first place shouldn't see any memory benefit, right?

@pengwa, Mar 1, 2023

> And with the wrapper, a model that doesn't use the label smoother in the first place shouldn't see any memory benefit, right?

This is a good question; I can help answer it. In short, whether the label smoothing factor argument is given or not, we see an improvement in memory. The key reason is that, after the change, the target model wrapped by ORTModule has a single loss output.

If the label smoothing factor argument is not given, then the CrossEntropyLoss is indeed computed inside the model's forward pass. But there is a minor catch here: the loss, along with lm_logits and other intermediate states, is returned in the results. If ORTModule wraps and operates on this model, the exported model has a few outputs besides the loss; those outputs are not used later during training, but the exporter doesn't know that. In the ORT training implementation, even though those outputs are unused, we still fill them with zeros and carry them through the whole backward propagation phase.

In this case, if we wrap model+loss together, the final output of the model+loss module that ORTModule wraps is just the loss.
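To make the point concrete, a standalone sketch of the model+loss idea (the wrapper name and signature are made up for illustration; this is not the PR's implementation):

import torch

class LossOnlyWrapper(torch.nn.Module):
    # Returns only the loss, so exporting this module produces a graph with a
    # single output instead of loss + lm_logits + other intermediate states.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, labels):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs.loss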

Collaborator

Thanks for the explanation @pengwa, that's very clear.


def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
return self.hf_trainer.compute_loss(self._original_model, inputs, return_outputs=False)

I think this is not correct. Done this way, self.label_smoother won't be used by compute_loss. What I suggested earlier is to bind the compute_loss of the HF Trainer to ModuleWithLoss. Here is the example:


class A:
    def __init__(self) -> None:
        self.prop = "A's prop"

    def f(self, x: int) -> int:
        print("A>>f is called, the prop used is: ", self.prop)
        return x + 1

class B:
    def __init__(self) -> None:
        self.prop = "B's prop"

    def main(self, x: int) -> int:
        print("B>>main is called")
        self.f(x)
        return x + 2

    def f(self, x: int) -> int:
        raise NotImplementedError()


b_instance = B()
import types

b_instance.f = types.MethodType(getattr(A,'f'), b_instance)

b_instance.main(1)

The output is:

B>>main is called
A>>f is called, the prop used is: B's prop

Collaborator

Here compute_loss() will use self.hf_trainer.label_smoother. Although this way the loss computation with label_smoother happens inside a forward pass and is intercepted by onnxruntime, the self.label_smoother defined in the init will not be used.

It's good that we can reuse the compute_loss function, but in terms of code clarity I would prefer to override the forward pass of the pretrained model instead of involving the Trainer.

(As discussed internally with the transformers team, it would be nice to have a wrapper directly in the transformers package that includes the loss computation in the forward pass when using label_smoother. But let's do that in optimum first: get this PR merged, test it, and then, once it is mature, migrate it to transformers. After that, maintaining ORTTrainer would be easier.)

Contributor Author

FYI @JingyaHuang, in this case the code in the Trainer has to be kept in sync: if compute_loss in the HF Trainer changes, it has to be changed in the forward pass of ModuleWithLoss as well.

Collaborator

Hi @AdamLouly, sorry for the back and forth. I proposed to rewrite the code because I was considering opening a PR in Transformers to put the label smoother inside forward; in that case we wouldn't need a wrapper in Optimum. But as @pengwa explained, a PR in Transformers won't be enough (we can't limit unnecessary outputs in Transformers, for flexibility reasons), so we will always need this wrapper in Optimum.

In that case, I agree that we should reuse compute_loss() (as you did before) to ease maintenance.


@property
def config(self):
return self._original_model.config


class ORTFeaturesManager:
_TASKS_TO_ORTMODELS = {
"default": ORTModelForFeatureExtraction,
@@ -279,12 +308,48 @@ def __init__(
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# We leverage both _training_model and _inferencing_model in conjunction with model.
# _training_model will be wrapped so it uses ORT and the overridden functions in ModuleWithLoss.
# _inferencing_model stores the default version of the model, and we switch to it for eval/test.

# Only wrap the model if the --loss_in_train flag is passed.
if args.loss_in_train:
self._training_model = ModuleWithLoss(model, args)

nit: maybe we can pass self into ModuleWithLoss in its constructor. Then you can get self.model, self.args, and even self.label_smoother inside the ModuleWithLoss class.

@pengwa, Feb 26, 2023

here is an example:

import types

def __init__(...):
    ....
    if args.loss_in_train:
       self._training_model = self.create_model_with_loss()
    ...

def create_model_with_loss(self):
    class ModuleWithLoss(nn.Module):
        def __init__(self, model, args, label_smoother):
            super().__init__()
            self._original_model = model
            self.args = args
            self.label_smoother = label_smoother

        def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
            # compute_model_plus_loss_internal is assigned once the class is instantiated.
            # It should have the same signature as Trainer.compute_loss().
            # We do this to avoid potentially un-synced states from duplicating the compute-loss code.
            return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)

        @property
        def config(self):
            return self._original_model.config

    model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother)
    model_with_loss.compute_model_plus_loss_internal = types.MethodType(self.super().compute_loss, model_with_loss)

    return model_with_loss

else:
self._training_model = model
Collaborator

Do we still need to distinguish the training model from the inference model, if we have the module property on the wrapper, given that we unwrap self.model for inference here:

self.model = unwrap_model(deepspeed_engine)
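For context, unwrap_model in transformers behaves roughly like the sketch below (paraphrased, not code from this PR), which is why a module property on the wrapper would let it reach the original model:

def unwrap_model(model):
    # Recursively follow `.module` attributes (used by DDP, ORTModule, and
    # similar wrappers) until the innermost model is reached.
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model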


self.model = model
self._inferencing_model = model
self.feature = feature
self.onnx_model_path = onnx_model_path
self.exported_with_loss = False
if self.args.local_rank:
torch.cuda.set_device(self.args.local_rank)

# We assume that training_model and inference_model have the same forward signature columns.
# The self._signature_columns attribute only stores the signature parsed the first time.
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
import inspect

if isinstance(self.model, ModuleWithLoss):
signature = inspect.signature(self.model._original_model.forward)
else:
signature = inspect.signature(self.model.forward)

self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def compute_loss(self, model_with_loss, inputs, return_outputs=False):
# Run model forward + loss compute.
if self.args.loss_in_train and self.model == self._training_model:

Have you ever run this with plain PyTorch when loss_in_train is enabled? Is that working?

outputs = model_with_loss(inputs, return_outputs)
return outputs
else:
return super().compute_loss(self.model, inputs, return_outputs)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -314,6 +379,8 @@ def train(
"https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime."
)

self.model = self._training_model

if resume_from_checkpoint is False:
resume_from_checkpoint = None

@@ -592,7 +659,6 @@ def _inner_training_loop(
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
_ = list(train_dataloader.sampler)

for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
@@ -801,6 +867,8 @@ def evaluate(
dictionary also contains the epoch number which comes from the training state.
"""
# memory metrics - must set up as early as possible
# TODO: We need to enable evaluation using ORT backend.

This TODO is only meant for the --loss_in_train flag, right?

self.model = self._inferencing_model
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
@@ -892,6 +960,9 @@ def predict(
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
# TODO: We need to enable evaluation using ORT backend.
self.model = self._inferencing_model

# memory metrics - must set up as early as possible
self._memory_tracker.start()

@@ -909,10 +980,7 @@

try:
output = eval_loop(
test_dataloader,
description="Prediction",
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
except Exception as error:
logger.error(error)
@@ -1697,11 +1765,7 @@ def create_optimizer(self):
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
self.optimizer = OSS(params=optimizer_grouped_parameters, optim=optimizer_cls, **optimizer_kwargs)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
@@ -1731,10 +1795,7 @@ def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, Any]:
The training arguments for the training session.
"""
optimizer_kwargs = {"lr": args.learning_rate}
adam_kwargs = {
"betas": (args.adam_beta1, args.adam_beta2),
"eps": args.adam_epsilon,
}
adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}
if args.optim == ORTOptimizerNames.ADAMW_ORT_FUSED:
try:
from onnxruntime.training.optim import FusedAdam
14 changes: 14 additions & 0 deletions optimum/onnxruntime/training_args.py
@@ -65,6 +65,11 @@ class ORTTrainingArguments(TrainingArguments):
metadata={"help": "The optimizer to use."},
)

loss_in_train: Optional[bool] = field(
default=False,
metadata={"help": "Use ModuleWithLoss Wrapper to compute loss inside the training loop."},
)

# This method will not need to be overridden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers.
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
Expand Down Expand Up @@ -336,3 +341,12 @@ def __post_init__(self):
f"{self.hub_model_id}).",
FutureWarning,
)
if self.loss_in_train:
logger.info(
"Using the ModuleWithLoss wrapper: the loss will be computed inside the training step, "
"which lowers the peak memory usage."
)
else:
logger.info(
"Not using the ModuleWithLoss wrapper."
)
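For context, a minimal sketch of how the new flag might be enabled programmatically (all argument values other than loss_in_train are placeholders, not taken from this PR):

from optimum.onnxruntime import ORTTrainingArguments

training_args = ORTTrainingArguments(
    output_dir="out",
    loss_in_train=True,  # wrap the model with ModuleWithLoss so the loss is computed in the training step
)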