Allow resume_from_checkpoint to handle auto_find_batch_size #27568
Conversation
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
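For context, a minimal sketch of the idea behind this diff; the helper name restore_train_batch_size and the getattr guard are assumptions for illustration, not the exact code in the PR:

```python
import os

from transformers.trainer import TRAINER_STATE_NAME  # "trainer_state.json"
from transformers.trainer_callback import TrainerState

def restore_train_batch_size(trainer, resume_from_checkpoint: str):
    # Hypothetical helper: read the state saved by a previous run and, if it
    # recorded the batch size that auto_find_batch_size settled on, reuse it
    # instead of starting the OOM search again.
    state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
    if getattr(state, "train_batch_size", None) is not None:
        trainer._train_batch_size = state.train_batch_size
```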
I don't necessarily enjoy that we load it up here just for one value, but it makes sense to keep this metadata here rather than in the model metadata, as that makes no sense; nor does it make sense to create a whole new dataclass/state/thing just to store it in.
Indeed, it's not ideal. How large is it / how long does it take to load?
Could we protect it behind the `if state.train_batch_size is not None` branch until there's a need for the state later?
Not large at all, it just has a few dataclass entries in it. However, we can certainly protect it to reduce IO time.
Also, that's not really possible: a user should be able to use auto_find_batch_size on a first run and then not need it again when resuming from a checkpoint, as long as that checkpoint stored the information about the prior run for us to load in automatically, so the load is always needed. This only happens when we resume from a checkpoint, which should limit it enough.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
LGTM
I don't know the context enough, but loading the whole train state seems a bit illogical to me.
- the auto_find_batch_size does not overwrite the argument's batch size, which is why we have to look for it when we resume?
- not sure I understand the test either; let's try to make it apparent that there are two different values, the input arg and the _train_batch_size that is overwritten by auto_find_batch_size (sketched below)
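To illustrate the distinction, a hedged sketch of what auto_find_batch_size does under the hood: the Trainer effectively wraps its inner loop with accelerate's find_executable_batch_size, which retries on CUDA OOM with a halved batch size, so only the internal _train_batch_size changes while args.per_device_train_batch_size stays untouched. The toy function below is illustrative, not the actual Trainer code:

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=16)
def inner_training_loop(batch_size):
    # On a CUDA out-of-memory error, accelerate halves `batch_size` and calls
    # this function again; the value that finally works is what the Trainer
    # keeps as `self._train_batch_size`.
    print(f"trying batch_size={batch_size}")
    # ... the actual training steps would go here ...

inner_training_loop()
```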
tests/trainer/test_trainer.py (Outdated)
# assume that `auto_find_bs` set it to 8, and we were originally at 16
trainer.args.per_device_train_batch_size = 16
trainer.train(resume_from_checkpoint=True)
# We should be back to 16 again
to 8, no?
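To spell out the two values under discussion, a hedged sketch of the intended assertions; `trainer` stands for the dummy trainer built by the test, and the concrete numbers 8 and 16 are assumptions carried over from the snippet above:

```python
# First run: the user asked for 16, but auto_find_batch_size hit OOM and, in
# this scenario, settled on 8; that value gets recorded in trainer_state.json.
user_batch_size = trainer.args.per_device_train_batch_size  # stays 16, never overwritten
found_batch_size = trainer._train_batch_size                # 8, set by auto-find

# Resume: with this PR the stored value wins, so the final check should be 8.
trainer.train(resume_from_checkpoint=True)
assert trainer._train_batch_size == found_batch_size
assert trainer.args.per_device_train_batch_size == user_batch_size
```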
@ArthurZucker agreed that it's a bit overkill. Would it be better to create a new file instead (something like
Why don't we just overwrite the arg given by the user?
@ArthurZucker we still need to store it away somewhere when we do
Ah okay, we don't know if the input batch was auto-found or not. Got it. Not sure we want to create a new file for this; fine with loading the state, and if we need more metadata we'll put it there as well, I guess!
Thanks for adding!
Some comments and questions on the state management. I don't know the Trainer in-depth, so I might be misunderstanding how it's meant to behave.
max_steps=2,
save_steps=1,
per_device_train_batch_size=8,
auto_find_batch_size=True,
I don't think I know enough about auto_find_batch_size to understand the implication of this test. If auto_find_batch_size is True and per_device_train_batch_size is set, which one takes precedence?
The precedence is: batch_size stored in the metadata > per_device_train_batch_size > auto_find_batch_size if still OOM.
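A hedged sketch of that precedence as a standalone helper (the function name and exact branching are assumptions; the attribute names come from the diff):

```python
def resolve_train_batch_size(args, state, resuming: bool) -> int:
    # Hypothetical helper mirroring the precedence described above.
    # 1) A batch size stored in the checkpoint metadata wins on resume.
    if resuming and getattr(state, "train_batch_size", None) is not None:
        return state.train_batch_size
    # 2) Otherwise fall back to the user-provided argument.
    # 3) If that still OOMs and args.auto_find_batch_size is True, the Trainer's
    #    find_executable_batch_size wrapper halves the value and retries.
    return args.per_device_train_batch_size * max(1, args.n_gpu)
```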
@@ -1641,6 +1647,7 @@ def _inner_training_loop(
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
Won't this mean the batch_size from the state is always loaded, even if self.args.auto_find_batch_size is False?
Yes, we care about whether it was called at all, and the metadata exists inside the state.
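For completeness, a small, hedged demonstration of the round trip; treating train_batch_size as a TrainerState dataclass field follows this PR's diff, but the exact keyword is an assumption about the merged version:

```python
from transformers.trainer_callback import TrainerState

# Saving: the Trainer serializes its full state when checkpointing, so the new
# field rides along with everything else in trainer_state.json.
state = TrainerState(train_batch_size=8)
state.save_to_json("trainer_state.json")

# Loading: the value comes back whether or not auto_find_batch_size is enabled
# for the new run, which is exactly the behavior discussed above.
restored = TrainerState.load_from_json("trainer_state.json")
print(restored.train_batch_size)  # 8
```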
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
(branch force-pushed from 0d798d9 to 780bf72)
Thanks! Having different variable names, like an actual_train_batch_size vs. a user_input_train_batch_size, might help differentiate the two, but it's a nit.
What does this PR do?
This PR adds the training batch size as part of the TrainerState. We do this because the TrainerState can be loaded in on resume_from_checkpoint, so if a user has set auto_find_batch_size to True, we can keep what that batch size was and load it back in if it was saved.
Fixes #25956
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts