Compute Loss inside the training step. #686
Merged. JingyaHuang merged 18 commits into huggingface:main from AdamLouly:adamlouly/compute_loss_training on Mar 24, 2023.
Changes from all commits (18 commits):
All commits by AdamLouly:

69c6f11  improved solution
c918008  compute loss fix
40fd135  resolve conflict
7e96810  resolved comments
4732f2c  removed duplicated code; used main trainer compute loss
c47fa80  added --loss_in_train flag
040d77e  resolve conflict
c45bd53  resolved comments
31178c8  resolved comments
19cfe04  formatted using latest black
f268040  add import for code quality
4d8624a  formatted using latest black
ee6ef10  re-adding super loss compute
dc8de71  resolved comments
55ad1d2  fix typo
432efe5  solve not exporting onnx models
2b5e57b  dictionary casting, bind method
b6ccb53  trainer fix with ruff
Diff view:
@@ -20,6 +20,7 @@
 import shutil
 import sys
 import time
+import types
 import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -134,6 +135,32 @@
 SCALER_NAME = "scaler.pt"
 
 
+class ModuleWithLoss(nn.Module):
+    def __init__(self, model, args, label_smoother):
+        super().__init__()
+        self._original_model = model
+        self.args = args
+        # Label smoothing
+        self.label_smoother = label_smoother
+
+    def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
+        # compute_model_plus_loss_internal is assigned once the class is instantiated.
+        # It should have the same signature as Trainer.compute_loss().
+        # We do this to avoid the potentially un-synced state that duplicating the compute-loss code would create.
+        return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)
+
+    @property
+    def module(self):
+        """The original `torch.nn.Module` that this module wraps.
+        This property provides access to methods and properties on the original module."""
+        return self._original_model.module
+
+    @property
+    def config(self):
+        return self._original_model.config
+
+
 class ORTFeaturesManager:
     _TASKS_TO_ORTMODELS = {
         "default": ORTModelForFeatureExtraction,
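The wrapper relies on a binding trick spelled out in create_model_with_loss() below: the upstream Trainer.compute_loss is attached to the instance as a bound method instead of being copied, so the loss logic cannot drift out of sync with transformers. A minimal, self-contained sketch of the same pattern; the model, batch, and stand-in compute_loss here are illustrative, not the PR's code:

import types

import torch
import torch.nn as nn


class WrapperWithLoss(nn.Module):
    """Toy version of ModuleWithLoss: forward() returns the loss directly."""

    def __init__(self, model):
        super().__init__()
        self._original_model = model

    def forward(self, inputs, return_outputs=False):
        # Delegates to whatever loss function gets bound after instantiation.
        return self.compute_loss_internal(self._original_model, inputs, return_outputs)


def compute_loss(self, model, inputs, return_outputs=False):
    # Stand-in for Trainer.compute_loss: a forward pass plus an MSE loss.
    labels = inputs.pop("labels")
    outputs = model(inputs["x"])
    loss = nn.functional.mse_loss(outputs, labels)
    return (loss, outputs) if return_outputs else loss


model = nn.Linear(4, 1)
wrapped = WrapperWithLoss(model)
# types.MethodType binds the free function to this instance, mirroring how
# create_model_with_loss() binds Trainer.compute_loss onto ModuleWithLoss.
wrapped.compute_loss_internal = types.MethodType(compute_loss, wrapped)

batch = {"x": torch.randn(2, 4), "labels": torch.randn(2, 1)}
loss = wrapped(batch)  # the training step sees a single loss tensor
loss.backward()

Because the wrapper's forward() returns the loss itself, the loss is computed inside the training step, which is what the PR credits for the reduced peak memory usage.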
@@ -279,12 +306,55 @@ def __init__(
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # We leverage both training_model and inference_model in conjunction with model.
+        # _training_model will be wrapped so it uses ORT and the overridden functions in ModuleWithLoss.
+        # The default version of the model is kept and unwrapped in case of eval/test.
+
+        # Only wrap the model if the --use_module_with_loss flag is passed.
+        if args.use_module_with_loss:
+            self._training_model = self.create_model_with_loss()
+
+        self.model = model
+
         self.feature = feature
         self.onnx_model_path = onnx_model_path
         self.exported_with_loss = False
         if self.args.local_rank:
             torch.cuda.set_device(self.args.local_rank)
 
+    # This method creates a ModuleWithLoss instance to use when the --use_module_with_loss flag is passed.
+    # It helps reduce peak memory usage by computing the loss inside the training step.
+    def create_model_with_loss(self):
+        model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother)
+        model_with_loss.compute_model_plus_loss_internal = types.MethodType(Trainer.compute_loss, model_with_loss)
+        return model_with_loss
+
+    # We assume that training_model and inference_model have the same forward signature columns;
+    # the self._signature_columns attribute only stores the first-time parsed signature.
+    def _set_signature_columns_if_needed(self):
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            import inspect
+
+            if isinstance(self.model, ModuleWithLoss):
+                signature = inspect.signature(self.model._original_model.forward)
+            else:
+                signature = inspect.signature(self.model.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids; the default data collator handles that.
+            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+    def compute_loss(self, model_with_loss, inputs, return_outputs=False):
+        # Run model forward + loss computation.
+        if isinstance(self.model, ModuleWithLoss):
+            # ORTModule does not support the BatchEncoding type, so we have to convert the inputs to a dict.
+            dict_inputs = dict(inputs.items())
+            return model_with_loss(dict_inputs, return_outputs)
+        else:
+            return super().compute_loss(model_with_loss, inputs, return_outputs)
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
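The _set_signature_columns_if_needed override matters because Trainer uses the parsed signature to drop dataset columns that the model's forward() does not accept; inspecting the wrapper's forward (inputs, return_outputs) instead of the wrapped model's would discard every real feature column. A rough illustration of the mechanism, with a toy forward standing in for a transformer model:

import inspect


def forward(input_ids=None, attention_mask=None, labels=None):
    """Toy stand-in for a transformer model's forward()."""


signature_columns = list(inspect.signature(forward).parameters.keys())
signature_columns += ["label", "label_ids"]  # names the default data collator understands

batch = {
    "input_ids": [[101, 102]],
    "attention_mask": [[1, 1]],
    "text": ["hello world"],  # raw dataset column the model cannot accept
}
# The Trainer keeps only the columns the forward signature accepts:
filtered = {k: v for k, v in batch.items() if k in signature_columns}
print(sorted(filtered))  # ['attention_mask', 'input_ids']

The dict(inputs.items()) cast in compute_loss is similar glue: a transformers BatchEncoding behaves like a mapping, so casting it to a plain dict preserves the tensors while handing ORTModule a type it accepts.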
@@ -313,6 +383,8 @@ def train(
                 "You need to install `onnxruntime-training` to use `ORTTrainer` for training. Check out "
                 "https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime."
             )
+        if self.args.use_module_with_loss:
+            self.model = self._training_model
 
         if resume_from_checkpoint is False:
             resume_from_checkpoint = None
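End to end, the feature is opt-in through the training arguments. A hypothetical usage sketch; note that use_module_with_loss as a constructor kwarg on ORTTrainingArguments is assumed here, since the argument-definition side of the PR is not part of this diff:

from transformers import AutoModelForSequenceClassification

from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# NOTE: passing use_module_with_loss as a kwarg is an assumption; the flag's
# definition lives in the training-arguments change, which this diff omits.
args = ORTTrainingArguments(output_dir="out", use_module_with_loss=True)

trainer = ORTTrainer(
    model=model,
    args=args,
    feature="sequence-classification",  # task used to pick the ORT model class
    train_dataset=train_dataset,  # hypothetical tokenized dataset, defined elsewhere
)
trainer.train()     # runs through the ModuleWithLoss wrapper
trainer.evaluate()  # evaluate()/predict() unwrap back to the original model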
@@ -801,6 +873,8 @@ def evaluate(
             dictionary also contains the epoch number which comes from the training state.
         """
         # memory metrics - must set up as early as possible
+        # TODO: We need to enable evaluation using the ORT backend.
+        self.model = unwrap_model(self.model)
         self._memory_tracker.start()
 
         eval_dataloader = self.get_eval_dataloader(eval_dataset)

Review comment on the TODO above: this TODO is only meant for the --loss_in_train flag, right?
@@ -892,6 +966,9 @@ def predict(
             - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
               labels).
         """
+        # TODO: We need to enable evaluation using the ORT backend.
+        self.model = unwrap_model(self.model)
+
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -909,10 +986,7 @@ def predict(
 
         try:
             output = eval_loop(
-                test_dataloader,
-                description="Prediction",
-                ignore_keys=ignore_keys,
-                metric_key_prefix=metric_key_prefix,
+                test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
             )
         except Exception as error:
             logger.error(error)
@@ -1522,13 +1596,7 @@ def _export(
         opset = max(opset, 12)  # Operators like `nll_loss` are added for opset>=12
 
         output_path = model_path / ONNX_WEIGHTS_NAME
-        _ = export(
-            model=model,
-            config=onnx_config,
-            opset=opset,
-            output=output_path,
-            device=device,
-        )
+        _ = export(model=model, config=onnx_config, opset=opset, output=output_path, device=device)
 
         model.config.save_pretrained(model_path)
@@ -1697,11 +1765,7 @@ def create_optimizer(self):
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
 
             if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
+                self.optimizer = OSS(params=optimizer_grouped_parameters, optim=optimizer_cls, **optimizer_kwargs)
             else:
                 self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                 if optimizer_cls.__name__ == "Adam8bit":
@@ -1731,10 +1795,7 @@ def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, Any]:
                 The training arguments for the training session.
         """
         optimizer_kwargs = {"lr": args.learning_rate}
-        adam_kwargs = {
-            "betas": (args.adam_beta1, args.adam_beta2),
-            "eps": args.adam_epsilon,
-        }
+        adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}
         if args.optim == ORTOptimizerNames.ADAMW_ORT_FUSED:
             try:
                 from onnxruntime.training.optim import FusedAdam
Review comment: With ModuleWithLoss as a wrapper for a `torch.nn.Module` subclass, can you add a `module` property so that `unwrap_model` could be compatible? I believe `ORTModule` did the same.
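For context, transformers' unwrap_model peels wrappers by recursing on the .module attribute, which is why this property was requested (and added in the diff above). A simplified sketch of that helper's logic, written from memory rather than copied from transformers, so treat the exact body as an approximation:

import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Recursively unwrap containers (DistributedDataParallel, ORTModule,
    ModuleWithLoss, ...) that expose the wrapped model as `.module`."""
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model

With the `module` property in place, unwrap_model recurses through _original_model.module, so evaluate() and predict() get back a plain model rather than the loss wrapper.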