Allow `resume_from_checkpoint` to handle `auto_find_batch_size` #27568
Changes from all commits: 40696e6, dbbc71f, 096fd7c, b4a1903, 497a44b, ed31b37, 780bf72
src/transformers/trainer.py
@@ -1507,6 +1507,10 @@ def train(
             and not self.is_fsdp_enabled
         ):
             self._load_from_checkpoint(resume_from_checkpoint)
+            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            if state.train_batch_size is not None:
+                self._train_batch_size = state.train_batch_size

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
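For context, the value restored here comes from the `trainer_state.json` that the Trainer writes into every checkpoint folder. Below is a minimal sketch of that load path; the helper function is hypothetical, and only `TRAINER_STATE_NAME` and the `train_batch_size` field come from the diff above.

```python
import json
import os

# In transformers, TRAINER_STATE_NAME is "trainer_state.json"; hard-coded here to stay self-contained.
TRAINER_STATE_NAME = "trainer_state.json"


def resolve_train_batch_size(resume_from_checkpoint: str, default_batch_size: int) -> int:
    """Hypothetical helper mirroring the logic added above: a batch size recorded in the
    checkpoint's trainer state (e.g. one found by `auto_find_batch_size`) takes precedence
    over the value derived from `per_device_train_batch_size`."""
    state_path = os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    with open(state_path, encoding="utf-8") as f:
        state = json.load(f)
    saved = state.get("train_batch_size")
    return saved if saved is not None else default_batch_size
```

A checkpoint produced by the test later in this diff would make `resolve_train_batch_size(checkpoint_dir, 16)` return 8.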
@@ -1542,6 +1546,8 @@ def _inner_training_loop(
     ):
         self.accelerator.free_memory()
         self._train_batch_size = batch_size
+        if self.args.auto_find_batch_size:
+            self.state.train_batch_size = self._train_batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
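The state has to be refreshed at this point because, with `auto_find_batch_size=True`, the Trainer wraps its inner training loop in a find-executable-batch-size helper that retries with a smaller batch size after a CUDA OOM, so the batch size that finally sticks is only known at runtime. The decorator below is a simplified, standalone sketch of that retry pattern, not the actual transformers/accelerate implementation.

```python
import functools


def find_executable_batch_size_sketch(function, starting_batch_size=16):
    """Simplified halve-on-OOM retry loop (illustration only).

    `function` must take the batch size as its first argument, much like
    `_inner_training_loop(batch_size, ...)` does in the Trainer.
    """

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        batch_size = starting_batch_size
        while batch_size > 0:
            try:
                return function(batch_size, *args, **kwargs)
            except RuntimeError as exc:
                if "out of memory" in str(exc).lower():
                    batch_size //= 2  # retry the whole loop with half the batch size
                else:
                    raise
        raise RuntimeError("No executable batch size found, reached a batch size of 0.")

    return wrapper
```

Persisting `state.train_batch_size` here is what lets a resumed run skip this search and start directly from the size that survived.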
@@ -1618,6 +1624,7 @@ def _inner_training_loop(

         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None
+        self.state.train_batch_size = self._train_batch_size

         # Compute absolute values for logging, eval, and save if given as ratio
         if args.logging_steps is not None:

Review comment: Won't this mean the batch_size from the state is always loaded, even if `auto_find_batch_size` isn't enabled?

Reply: Yes, we care about if it was called at all and the metadata exists inside the state.
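On the save side, `TrainerState` is a dataclass that gets serialized to `trainer_state.json` inside each checkpoint, and this PR adds a `train_batch_size` field to it. The round trip below uses a stripped-down stand-in dataclass (not the real `TrainerState`, which has many more fields) just to show how the discovered batch size survives into, and back out of, a checkpoint.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass
class MiniTrainerState:
    # Stand-in with only the field relevant to this PR.
    train_batch_size: Optional[int] = None

    def save_to_json(self, json_path: str) -> None:
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self), f, indent=2)

    @classmethod
    def load_from_json(cls, json_path: str) -> "MiniTrainerState":
        with open(json_path, encoding="utf-8") as f:
            return cls(**json.load(f))


# The size found during the first run is written at save time ...
MiniTrainerState(train_batch_size=8).save_to_json("trainer_state.json")
# ... and read back when resuming, much like the `load_from_json` call in the first hunk.
assert MiniTrainerState.load_from_json("trainer_state.json").train_batch_size == 8
```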
tests/trainer/test_trainer.py
@@ -38,6 +38,7 @@
     AutoTokenizer,
     IntervalStrategy,
     PretrainedConfig,
+    TrainerCallback,
     TrainingArguments,
     get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
@@ -1532,6 +1533,41 @@ def test_auto_batch_size_finder(self):
         with patch.object(sys, "argv", testargs):
             run_glue.main()

+    def test_auto_batch_size_with_resume_from_checkpoint(self):
+        train_dataset = RegressionDataset(length=128)
+
+        config = RegressionModelConfig(a=0, b=2)
+        model = RegressionRandomPreTrainedModel(config)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+
+        class MockCudaOOMCallback(TrainerCallback):
+            def on_step_end(self, args, state, control, **kwargs):
+                # simulate OOM on the first step
+                if state.train_batch_size == 16:
+                    raise RuntimeError("CUDA out of memory.")
+
+        args = RegressionTrainingArguments(
+            tmp_dir,
+            do_train=True,
+            max_steps=2,
+            save_steps=1,
+            per_device_train_batch_size=16,
+            auto_find_batch_size=True,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
+        trainer.train()
+        # After `auto_find_batch_size` has run, we should now be at 8
+        self.assertEqual(trainer._train_batch_size, 8)
+
+        # We can then make a new Trainer
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        # Check we are at 16 to start
+        self.assertEqual(trainer._train_batch_size, 16)
+        trainer.train(resume_from_checkpoint=True)
+        # We should be back to 8 again, picking up based upon the last run Trainer
+        self.assertEqual(trainer._train_batch_size, 8)
+
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)
Review comment: Don't necessarily enjoy the fact that we load it up here just for one value, but it makes sense to keep this metadata in here rather than in the model metadata, as that makes no sense, nor does it make sense to have a whole new dataclass/state/thing for us to store it in either.

Reply: Indeed, it's not ideal. How large is it / how long does it take to load? Could we protect it behind the `if state.train_batch_size is not None` branch until there's a need for the state later?

Reply: Not large at all, it's just got a few dataclass entries in it. However, we can certainly protect it to reduce IO time.

Reply: Also, that's not really possible: a user should be able to run with `auto_find_batch_size` first and then not require it again when resuming from a checkpoint, provided that checkpoint stored information about the prior run for us to load in automatically, so the state is always needed. This only happens when resuming from a checkpoint, which should limit it enough.
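Putting that discussion together, the intended flow looks roughly like the sketch below. The `TrainingArguments` fields are real; the model, dataset, and output directory are placeholders to be filled in, so this is a flow sketch rather than a copy-paste script.

```python
from transformers import Trainer, TrainingArguments

# Placeholders: substitute any Trainer-compatible model and dataset here.
model = ...
train_dataset = ...

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=64,  # starting point for the search
    auto_find_batch_size=True,       # first run: halve on OOM until training fits
    save_steps=500,
)

# First run: the batch size that finally fits is recorded in the checkpoint's trainer_state.json.
Trainer(model=model, args=args, train_dataset=train_dataset).train()

# Resumed run: with this PR, that recorded batch size is restored from the checkpoint,
# so the search does not restart from 64.
Trainer(model=model, args=args, train_dataset=train_dataset).train(resume_from_checkpoint=True)
```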