Allow resume_from_checkpoint to handle auto_find_batch_size #27568

Merged (7 commits) on Dec 8, 2023
7 changes: 7 additions & 0 deletions src/transformers/trainer.py
@@ -1507,6 +1507,10 @@ def train(
and not self.is_fsdp_enabled
):
self._load_from_checkpoint(resume_from_checkpoint)
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
Comment on lines +1510 to +1511
Contributor Author:

I don't necessarily enjoy the fact that we load this up here just for one value, but it makes more sense to keep this metadata here than in the model metadata (which makes no sense), and it doesn't make sense to add a whole new dataclass/state just to store it either.

Collaborator:

Indeed, it's not ideal. How large is it / how long does it take to load?

Could we protect it behind the `if state.train_batch_size is not None` branch until there's a need for the state later?

Contributor Author:

Not large at all; it just has a few dataclass entries in it. However, we can certainly protect it to reduce I/O time.

Contributor Author:

Also, that's not really possible: a user should be able to use auto_find_batch_size on a first run and then not need it again when resuming from a checkpoint, since the checkpoint stored information about the prior run for us to load automatically. So the state is always needed here. This only happens when resuming from a checkpoint, which should limit the cost enough.
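A hedged sketch of the flow described above, mirroring the test added in this PR (`model`, `train_dataset`, and the output directory are placeholders):

```python
# Sketch only: the first run hits OOM at 16 and settles on a smaller batch size;
# the second run resumes from the checkpoint and should pick that size back up
# from trainer_state.json rather than re-running the search.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",                 # placeholder path
    per_device_train_batch_size=16,
    auto_find_batch_size=True,
    save_steps=1,
    max_steps=2,
)

trainer = Trainer(model, args, train_dataset=train_dataset)  # model/dataset assumed defined
trainer.train()                       # e.g. ends up training at batch size 8

trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train(resume_from_checkpoint=True)   # resumes at 8, not 16
```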

if state.train_batch_size is not None:
self._train_batch_size = state.train_batch_size

# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
@@ -1542,6 +1546,8 @@ def _inner_training_loop(
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
if self.args.auto_find_batch_size:
self.state.train_batch_size = self._train_batch_size
muellerzr marked this conversation as resolved.
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@ def _inner_training_loop(

self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
Collaborator:

Won't this mean the `batch_size` from the state is always loaded, even if `self.args.auto_find_batch_size` is `False`?

Contributor Author:

Yes, we care about whether it was called at all and whether the metadata exists inside the state.
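A minimal sketch of what the resume path does with the stored value (the checkpoint path is hypothetical; the field name follows this PR's diff):

```python
# Sketch: trainer_state.json is read on resume and, if train_batch_size was
# recorded there, it is applied regardless of the current auto_find_batch_size flag.
import os

from transformers.trainer_callback import TrainerState

checkpoint_dir = "out/checkpoint-1"   # hypothetical checkpoint directory
state = TrainerState.load_from_json(os.path.join(checkpoint_dir, "trainer_state.json"))
if state.train_batch_size is not None:
    print("Resuming with train_batch_size =", state.train_batch_size)
```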


# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
4 changes: 4 additions & 0 deletions src/transformers/trainer_callback.py
@@ -59,6 +59,9 @@ class TrainerState:
Run an evaluation every X steps.
save_steps (`int`, *optional*, defaults to 500):
Save checkpoint every X updates steps.
train_batch_size (`int`, *optional*):
The batch size for the training dataloader. Only needed when
`auto_find_batch_size` has been used.
num_input_tokens_seen (`int`, *optional*, defaults to 0):
The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
logging_steps: int = 500
eval_steps: int = 500
save_steps: int = 500
train_batch_size: int = None
num_train_epochs: int = 0
num_input_tokens_seen: int = 0
total_flos: float = 0
36 changes: 36 additions & 0 deletions tests/trainer/test_trainer.py
@@ -38,6 +38,7 @@
AutoTokenizer,
IntervalStrategy,
PretrainedConfig,
TrainerCallback,
TrainingArguments,
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
@@ -1532,6 +1533,41 @@ def test_auto_batch_size_finder(self):
with patch.object(sys, "argv", testargs):
run_glue.main()

def test_auto_batch_size_with_resume_from_checkpoint(self):
train_dataset = RegressionDataset(length=128)

config = RegressionModelConfig(a=0, b=2)
model = RegressionRandomPreTrainedModel(config)

tmp_dir = self.get_auto_remove_tmp_dir()

class MockCudaOOMCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
# simulate OOM on the first step
if state.train_batch_size == 16:
raise RuntimeError("CUDA out of memory.")

args = RegressionTrainingArguments(
tmp_dir,
do_train=True,
max_steps=2,
save_steps=1,
per_device_train_batch_size=16,
auto_find_batch_size=True,
Collaborator:

I don't think I know enough about `auto_find_batch_size` to understand the implication of this test. If `auto_find_batch_size` is `True` and `per_device_train_batch_size` is set, which one takes precedence?

Contributor Author:

A batch_size stored in the checkpoint metadata takes precedence, then per_device_train_batch_size, and auto_find_batch_size kicks in if that still OOMs (see the sketch after this test).

)
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
trainer.train()
# After `auto_find_batch_size` has run, we should now be at 8
self.assertEqual(trainer._train_batch_size, 8)

# We can then make a new Trainer
trainer = Trainer(model, args, train_dataset=train_dataset)
# Check we are at 16 to start
self.assertEqual(trainer._train_batch_size, 16)
trainer.train(resume_from_checkpoint=True)
# We should be back to 8 again, picking up from the last Trainer run
self.assertEqual(trainer._train_batch_size, 8)
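For context on the precedence answer above, a standalone sketch of the halving behavior behind `auto_find_batch_size`. This uses `accelerate`'s `find_executable_batch_size` decorator directly, outside of `Trainer`, so it is an illustration rather than the Trainer internals:

```python
# Sketch: the decorated function is retried with a halved batch size every time
# it raises an out-of-memory style error, which is why the test above lands on 8.
from accelerate.utils import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=16)
def train_step(batch_size):
    if batch_size > 8:
        # pretend anything above 8 does not fit in memory
        raise RuntimeError("CUDA out of memory.")
    return batch_size


print(train_step())  # 16 fails, retried at 8 -> prints 8
```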

# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
train_dataset = RegressionDataset(length=128)