Compute Loss inside the training step. #686
useless references?
This was coming from main when I resolved the conflict.
But ideally, after the merge, these two lines should not appear in the diff.
Will remove this to pass the code quality check.
Ok, I don't have push access. Can you use ruff styling with the command `make style` and remove the redundant dependencies? Thx!
I did `make style` and it formatted trainer.py, but the CI still says it should be formatted. Is there a specific configuration for this?
It seems like the CI uses the latest black version every time, so we should always upgrade black before formatting, but then it will also reformat other files that were previously formatted with a different version.
Hi @AdamLouly, yes, the CI always uses the latest formatting tools, and whenever the team observes a failure of the check code quality CI, we fix it.
Your previous formatting issue could come from the fact that we recently switched from isort to ruff (#760). If you want to be more cautious, you can update your formatter with `pip install -U .[quality]` before `make style`.
As `ModuleWithLoss` is a wrapper for a `torch.nn.Module` subclass, can you add a `module` property so that `unwrap_model` can stay compatible? I believe `ORTModule` did the same.
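For reference, a minimal sketch of what that property could look like (hypothetical, not the PR's final code). It relies on the fact that transformers' `unwrap_model` recursively follows `.module` attributes, so exposing the wrapped model this way keeps the wrapper transparent:

```python
import torch


class ModuleWithLoss(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self._original_model = model

    @property
    def module(self):
        # unwrap_model(model) keeps recursing while `model.module` exists,
        # so returning the inner model here lets it reach the pretrained model.
        return self._original_model

    def forward(self, *args, **kwargs):
        return self._original_model(*args, **kwargs)
```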
Why is there a case for not using the label smoother? In transformers, unless the label smoother is used, the loss should already be calculated in the forward pass. Cf. gpt2.
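For context, a quick snippet (not part of the PR) showing that GPT-2 already computes the loss inside forward once labels are passed:

```python
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)          # scalar loss computed inside forward
print(outputs.logits.shape)  # logits are returned alongside it
```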
And with the wrapper, a model that does not use the label smoother in the first place shall not have any benefit on memory, right?
This is a good question; I can help answer it. In short, whether the label smoothing factor argument is given or not, we see an improvement in memory. The key reason is that the target model `ORTModule` wraps will have one single loss output after the change.
If the label smoothing factor argument is not given, then the CrossEntropyLoss is indeed computed inside the model's forward pass. But there is a minor subtlety here: the `loss` is returned in the results along with `lm_logits` and other intermediate states. If `ORTModule` wraps and operates on this model, the exported model will have a few outputs besides `loss`. Those outputs are not used later during training, but the exporter does not know that. In the ORT training implementation, even though those outputs are unused, we still fill them with zeros and carry them through the whole backward propagation phase. If instead we wrap model+loss together, the final output of the model+loss module that `ORTModule` wraps is just the loss.
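To make the output-count argument concrete, here is a toy illustration (not the PR code): the bare model returns loss, logits and cached states, while a model+loss wrapper exposes a single tensor to `ORTModule`:

```python
import torch
from transformers import AutoModelForCausalLM


class ModelPlusLoss(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, labels):
        outputs = self.model(input_ids=input_ids, labels=labels)
        # Only the loss is returned, so the exported graph has a single output
        # instead of logits / past_key_values that would be zero-filled in training.
        return outputs.loss


model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = torch.tensor([[464, 3290, 318]])
print(model(input_ids=ids, labels=ids).keys())  # loss, logits, past_key_values
print(ModelPlusLoss(model)(ids, ids))           # just the loss tensor
```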
Thanks for the explanation @pengwa, that's very clear.
I think this is not correct. Done this way, `self.label_smoother` won't be used by `compute_loss`. What I suggested earlier is to bind the `compute_loss` of the HF `Trainer` to `ModuleWithLoss`. Here is the example:
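The example snippet itself isn't rendered in the thread; a minimal reconstruction (the names `A`, `B`, `f`, `prop` are illustrative) that produces the two output lines quoted below could look like this:

```python
import types


class A:
    prop = "A's prop"

    def f(self):
        print(f"A>>f is called, the prop used is: {self.prop}")


class B:
    prop = "B's prop"

    def __init__(self):
        # Bind A.f onto this B instance, the way Trainer.compute_loss could be
        # bound onto ModuleWithLoss: `self` inside f then refers to the B object.
        self.f = types.MethodType(A.f, self)

    def main(self):
        print("B>>main is called")
        self.f()


B().main()
```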
B>>main is called
A>>f is called, the prop used is: B's prop
Here the `compute_loss()` will use `self.hf_trainer.label_smoother`, so the loss computed with `label_smoother` does happen under a forward pass and gets intercepted by onnxruntime; however, the `self.label_smoother` defined in the init will not be used.
It's good that we can reuse the `compute_loss` function, but in terms of code clarity I would prefer to override the forward pass of the pretrained model instead of having the Trainer involved. (As discussed internally with the transformers team, it would be nice to have a wrapper directly in the transformers package that includes the computation of the loss in the forward pass when using `label_smoother`. But let's do that for Optimum first: have this PR merged, test it, and then migrate it to transformers once it is mature. After that, it would be easier to maintain `ORTTrainer`.)
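A rough sketch of that alternative (hypothetical, not this PR's code), assuming a transformers version whose `LabelSmoother` accepts `shift_labels`:

```python
import torch
from transformers.trainer_pt_utils import LabelSmoother


class ForwardWithSmoothedLoss(torch.nn.Module):
    def __init__(self, model, epsilon: float = 0.1):
        super().__init__()
        self.model = model
        self.label_smoother = LabelSmoother(epsilon=epsilon)

    def forward(self, labels=None, **inputs):
        # No labels passed to the inner model, so it does not compute its own loss.
        outputs = self.model(**inputs)
        # shift_labels=True mirrors what Trainer.compute_loss does for causal LM models.
        return self.label_smoother(outputs, labels, shift_labels=True)
```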
FYI @JingyaHuang, in this case the code in the Trainer has to be maintained: if `compute_loss` in the HF Trainer gets changed, then it has to be changed in the forward pass of `ModuleWithLoss` as well.
Hi @AdamLouly, sorry for the back and forth. I proposed rewriting the code because I was considering opening a PR in Transformers to put the label smoother inside forward; if so, we wouldn't need a wrapper in Optimum. But as @pengwa explained, a PR in Transformers won't be enough (we can't limit unnecessary outputs in Transformers, for flexibility reasons), so we will always need this wrapper in Optimum.
In that case, I agree that we should inherit `compute_loss()` (as you did before) to ease maintenance.
nit: maybe we can pass `self` into `ModuleWithLoss` in its constructor. Then you can get `self.model`, `self.args`, and even `self.label_smoother` inside the `ModuleWithLoss` class.
here is an example:
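The original snippet isn't visible in the thread; a hypothetical reconstruction of the idea (pass the trainer itself, then delegate to its `compute_loss`) might look like this:

```python
import torch


class ModuleWithLoss(torch.nn.Module):
    def __init__(self, model, trainer):
        super().__init__()
        self._original_model = model
        # Gives access to trainer.args, trainer.label_smoother, etc.
        self.trainer = trainer

    def forward(self, inputs, return_outputs=False):
        # Reuse the HF Trainer's compute_loss so both code paths stay in sync.
        return self.trainer.compute_loss(
            self._original_model, inputs, return_outputs=return_outputs
        )

    @property
    def module(self):
        return self._original_model


# Inside ORTTrainer (sketch):
# self.model_with_loss = ModuleWithLoss(self.model, self)
```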
Do we still need to distinguish the training model and the inference model if we have the `module` property on the wrapper, given that we unwrap `self.model` for inference here:
optimum/optimum/onnxruntime/trainer.py
Line 534 in c45bd53
Have you ever run with PyTorch while `loss_in_train` is enabled? Is that working?
This TODO is only meant for the `--loss_in_train` flag, right?