
Allow resume_from_checkpoint to handle auto_find_batch_size #27568

Merged
muellerzr merged 7 commits into main from muellerzr-resume-auto-batch-size on Dec 8, 2023

Conversation

muellerzr (Contributor):

What does this PR do?

This PR adds the training batch size to the TrainerState. Because the TrainerState is loaded on resume_from_checkpoint, a batch size discovered when auto_find_batch_size=True can be stored there and loaded back in automatically when resuming.

Fixes #25956
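As an illustration of the intended flow (a minimal sketch; the model, dataset, and output paths here are placeholders for illustration, not anything taken from this PR):

from transformers import Trainer, TrainingArguments

# NOTE: `model` and `train_dataset` are assumed to be defined elsewhere.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=64,  # may be too large and OOM
    auto_find_batch_size=True,       # halve the batch size until training fits in memory
    save_steps=50,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # the batch size that actually worked is saved in the checkpoint's trainer_state.json

# On a later run, resuming picks up the stored batch size instead of
# re-running the search.
trainer.train(resume_from_checkpoint=True)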

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts

Comment on lines +1533 to +1511
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
muellerzr (Contributor Author):

I don't love that we load the state here just for one value, but it makes more sense to keep this metadata in the TrainerState than in the model metadata, and creating a whole new dataclass/state just to store it doesn't seem worth it either.

Collaborator:

Indeed, it's not ideal. How large is it / how long does it take to load?

Could we protect it behind the if state.train_batch_size is not None branch until there's a need for the state later?

muellerzr (Contributor Author):

Not large at all; it just has a few dataclass entries in it. However, we can certainly protect it to reduce I/O time.

muellerzr (Contributor Author):

Also, that's not really possible: a user should be able to run with auto_find_batch_size once, and when resuming from a checkpoint that stored information about the prior run, we load that batch size back in automatically without requiring the search again, so the state is always needed here. This only happens when resuming from a checkpoint, which should limit the cost enough.
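Putting this sub-thread together, the restore can be sketched as a small standalone helper (a reconstruction for illustration, not the Trainer's actual method; trainer._train_batch_size and state.train_batch_size are the attributes discussed above):

import os

from transformers import TrainerState

TRAINER_STATE_NAME = "trainer_state.json"  # file name the Trainer writes alongside checkpoints


def restore_auto_found_batch_size(trainer, resume_from_checkpoint: str) -> None:
    """Reload a batch size that auto_find_batch_size discovered on a prior run."""
    state = TrainerState.load_from_json(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    )
    if state.train_batch_size is not None:
        # Skip the halving search entirely and start from the value that worked before.
        trainer._train_batch_size = state.train_batch_size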


muellerzr requested review from ArthurZucker and amyeroberts and removed the review request for amyeroberts on November 20, 2023, 14:49
ArthurZucker (Collaborator) left a comment:

LGTM
I don't know the context well enough, but loading the whole train state seems a bit illogical to me.

  • auto_find_batch_size does not overwrite the argument's batch size, which is why we have to look it up when we resume?
  • I'm not sure I understand the test either; let's make it apparent that there are two different values: the input argument and the _train_batch_size that auto_find_batch_size overwrites.

# assume that `auto_find_bs` set it to 8, and we were originally at 16
trainer.args.per_device_train_batch_size = 16
trainer.train(resume_from_checkpoint=True)
# We should be back to 16 again
Collaborator:

To 8, no?

tests/trainer/test_trainer.py (review thread: outdated, resolved)
muellerzr (Contributor Author):

@ArthurZucker agreed that it's a bit overkill. Would it be better to create a new file (something like training_metadata.json) that, for now, only gets created when auto_find_batch_size is enabled?

ArthurZucker (Collaborator):

Why don't we just overwrite the arg given by the user?

muellerzr (Contributor Author):

@ArthurZucker we still need to store it somewhere for resume_from_checkpoint. The assumption is that on a fresh resume we don't want to run through the search loop again to find the right batch size if we already found it during a prior call, so it has to be persisted somewhere on the file system.
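For context on why the search is worth caching: the auto_find_batch_size path is built on accelerate's find_executable_batch_size utility (the helper the code comment quoted earlier refers to), which retries the wrapped function with a halved batch size after each out-of-memory error. A minimal standalone sketch, with a stand-in training body rather than the Trainer's real inner loop:

from accelerate.utils import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # Stand-in for a real training loop. On a CUDA out-of-memory error the
    # decorator frees memory, halves batch_size, and calls this function again,
    # so a fresh search can burn several failed attempts before succeeding.
    print(f"trying batch_size={batch_size}")
    # ... build dataloaders and run the training loop here ...


train()  # called with no arguments; the decorator injects batch_size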

ArthurZucker (Collaborator):

Ah okay, we don't know whether the input batch size was auto-found or not. Got it. I'm not sure we want to create a new file for this; I'm fine with loading the state, and if we need more metadata we can put it there as well.

amyeroberts (Collaborator) left a comment:

Thanks for adding!

Some comments and questions on the state management. I don't know the Trainer in depth, so I might be misunderstanding how it's meant to behave.

max_steps=2,
save_steps=1,
per_device_train_batch_size=8,
auto_find_batch_size=True,
Collaborator:

I don't think I know enough about auto_find_batch_size to understand the implication of this test. If auto_find_batch_size is True and per_device_train_batch_size is set - which one takes precedence?

muellerzr (Contributor Author):

The precedence is: a batch_size stored in the checkpoint metadata > per_device_train_batch_size > auto_find_batch_size if we still OOM.
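Spelled out as a hedged sketch (an illustrative helper, not an excerpt from trainer.py; state.train_batch_size and args.per_device_train_batch_size are the attributes discussed in this thread):

def pick_starting_batch_size(state, args):
    # 1. A batch size recorded in the checkpoint's TrainerState wins outright.
    if state is not None and state.train_batch_size is not None:
        return state.train_batch_size
    # 2. Otherwise start from the user's per_device_train_batch_size...
    # 3. ...and, if auto_find_batch_size=True, let the search loop halve it
    #    whenever training still runs out of memory.
    return args.per_device_train_batch_size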

tests/trainer/test_trainer.py (review thread: outdated, resolved)
src/transformers/trainer.py (review thread: resolved)
@@ -1641,6 +1647,7 @@ def _inner_training_loop(

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None
+       self.state.train_batch_size = self._train_batch_size
Collaborator:

Won't this mean the batch_size from the state is always loaded, even if self.args.auto_find_batch_size is False?

muellerzr (Contributor Author):

Yes; we care about whether it was ever called at all and whether the metadata exists inside the state.

src/transformers/trainer_callback.py (review thread: outdated, resolved)
muellerzr and others added 7 commits December 5, 2023 16:08
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
muellerzr force-pushed the muellerzr-resume-auto-batch-size branch from 0d798d9 to 780bf72 on December 5, 2023, 21:08
ArthurZucker (Collaborator) left a comment:

Thanks. Having distinct variable names, e.g. an actual train batch size vs. a user-input train batch size, might help differentiate the two, but it's a nit.

muellerzr merged commit 6757ed2 into main on Dec 8, 2023
3 checks passed
muellerzr deleted the muellerzr-resume-auto-batch-size branch on December 8, 2023, 16:51
Development

Successfully merging this pull request may close these issues.

resume_from_checkpoint may fail with auto_find_batch_size
4 participants