From 40696e65ff5a0c2fdf91902d49c45499ce0bc932 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 17 Nov 2023 13:03:30 -0500
Subject: [PATCH 1/7] Fulfill request

---
 src/transformers/trainer.py          | 7 +++++++
 src/transformers/trainer_callback.py | 4 ++++
 2 files changed, 11 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3c9e44201240..d8bfa5c54a65 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1507,6 +1507,10 @@ def train(
             and not self.is_fsdp_enabled
         ):
             self._load_from_checkpoint(resume_from_checkpoint)
+            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            if state["train_batch_size"] is not None:
+                self._train_batch_size = state["train_batch_size"]

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
@@ -1542,6 +1546,8 @@ def _inner_training_loop(
     ):
         self.accelerator.free_memory()
         self._train_batch_size = batch_size
+        if self.args.auto_find_batch_size:
+            self.state.train_batch_size = self._train_batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@ def _inner_training_loop(
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None
+        self.state.train_batch_size = self._train_batch_size

         # Compute absolute values for logging, eval, and save if given as ratio
         if args.logging_steps is not None:
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 13b2dcb6b089..843b5315cd09 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -59,6 +59,9 @@ class TrainerState:
             Run an evaluation every X steps.
         save_steps (`int`, *optional*, defaults to 500):
             Save checkpoint every X updates steps.
+        training_batch_size (`int`, *optional*):
+            The batch size for the training dataloader. Only needed when
+            `auto_find_batch_size` has been used.
         num_input_tokens_seen (`int`, *optional*, defaults to 0):
             The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
         total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
     logging_steps: int = 500
     eval_steps: int = 500
     save_steps: int = 500
+    train_batch_size: int = None
     num_train_epochs: int = 0
     num_input_tokens_seen: int = 0
     total_flos: float = 0

From dbbc71f441bdb388ea55e2daa2b2331108b4cd6c Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 17 Nov 2023 13:22:11 -0500
Subject: [PATCH 2/7] Add test

---
 src/transformers/trainer.py   |  4 ++--
 tests/trainer/test_trainer.py | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d8bfa5c54a65..925acc5a8104 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1509,8 +1509,8 @@ def train(
             self._load_from_checkpoint(resume_from_checkpoint)
             # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
             state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
-            if state["train_batch_size"] is not None:
-                self._train_batch_size = state["train_batch_size"]
+            if state.train_batch_size is not None:
+                self._train_batch_size = state.train_batch_size

         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 305ccb35d5b0..1459dd666367 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1532,6 +1532,29 @@ def test_auto_batch_size_finder(self):
             with patch.object(sys, "argv", testargs):
                 run_glue.main()

+    def test_auto_batch_size_with_resume_from_checkpoint(self):
+        train_dataset = RegressionDataset(length=128)
+
+        config = RegressionModelConfig(a=0, b=2)
+        model = RegressionRandomPreTrainedModel(config)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        args = RegressionTrainingArguments(
+            tmp_dir,
+            do_train=True,
+            max_steps=2,
+            save_steps=1,
+            per_device_train_batch_size=16,
+            auto_find_batch_size=True,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer.train()
+        # assume that `auto_find_bs` set it to 8
+        trainer.args.per_device_train_batch_size = 8
+        trainer.train(resume_from_checkpoint=True)
+        # We should be back to 16 again
+        self.assertEqual(trainer._train_batch_size, 16)
+
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)

From 096fd7cae5ea4b7d2843400767faa53aaaefe680 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 17 Nov 2023 13:25:26 -0500
Subject: [PATCH 3/7] Better test

---
 tests/trainer/test_trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 1459dd666367..3b39119879d8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1544,16 +1544,16 @@ def test_auto_batch_size_with_resume_from_checkpoint(self):
             do_train=True,
             max_steps=2,
             save_steps=1,
-            per_device_train_batch_size=16,
+            per_device_train_batch_size=8,
             auto_find_batch_size=True,
         )
         trainer = Trainer(model, args, train_dataset=train_dataset)
         trainer.train()
-        # assume that `auto_find_bs` set it to 8
-        trainer.args.per_device_train_batch_size = 8
+        # assume that `auto_find_bs` set it to 8, and we were originally at 16
+        trainer.args.per_device_train_batch_size = 16
         trainer.train(resume_from_checkpoint=True)
         # We should be back to 16 again
-        self.assertEqual(trainer._train_batch_size, 16)
+        self.assertEqual(trainer._train_batch_size, 8)

     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)

From b4a19032ad87c9ffc109210876dbd7faa6656bd7 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 5 Dec 2023 09:08:32 -0500
Subject: [PATCH 4/7] Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/trainer_callback.py | 2 +-
 tests/trainer/test_trainer.py        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 843b5315cd09..7533d7219c19 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -59,7 +59,7 @@ class TrainerState:
             Run an evaluation every X steps.
         save_steps (`int`, *optional*, defaults to 500):
             Save checkpoint every X updates steps.
-        training_batch_size (`int`, *optional*):
+        train_batch_size (`int`, *optional*):
             The batch size for the training dataloader. Only needed when
             `auto_find_batch_size` has been used.
         num_input_tokens_seen (`int`, *optional*, defaults to 0):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 3b39119879d8..ac9b0dba4ef1 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1544,7 +1544,7 @@ def test_auto_batch_size_with_resume_from_checkpoint(self):
             do_train=True,
             max_steps=2,
             save_steps=1,
-            per_device_train_batch_size=8,
+            per_device_train_batch_size=16,
             auto_find_batch_size=True,
         )
         trainer = Trainer(model, args, train_dataset=train_dataset)
@@ -1552,7 +1552,7 @@ def test_auto_batch_size_with_resume_from_checkpoint(self):
         # assume that `auto_find_bs` set it to 8, and we were originally at 16
         trainer.args.per_device_train_batch_size = 16
         trainer.train(resume_from_checkpoint=True)
-        # We should be back to 16 again
+        # We should be back to 8 again
         self.assertEqual(trainer._train_batch_size, 8)

     # regression for this issue: https://github.com/huggingface/transformers/issues/12970

From 497a44b54c0da1c2f74358e42bf892632c27c74b Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 5 Dec 2023 09:21:21 -0500
Subject: [PATCH 5/7] Better test

---
 tests/trainer/test_trainer.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index ac9b0dba4ef1..6e67787147e3 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -38,6 +38,7 @@
     AutoTokenizer,
     IntervalStrategy,
     PretrainedConfig,
+    TrainerCallback,
     TrainingArguments,
     get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
@@ -1539,6 +1540,13 @@ def test_auto_batch_size_with_resume_from_checkpoint(self):
         model = RegressionRandomPreTrainedModel(config)

         tmp_dir = self.get_auto_remove_tmp_dir()
+
+        class MockCudaOOMCallback(TrainerCallback):
+            def on_step_end(self, args, state, control, **kwargs):
+                # simulate OOM on the first step
+                if state.train_batch_size == 16:
+                    raise RuntimeError("CUDA out of memory.")
+
         args = RegressionTrainingArguments(
             tmp_dir,
             do_train=True,
             max_steps=2,
             save_steps=1,
             per_device_train_batch_size=16,
             auto_find_batch_size=True,
         )
-        trainer = Trainer(model, args, train_dataset=train_dataset)
+        trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
         trainer.train()
-        # assume that `auto_find_bs` set it to 8, and we were originally at 16
-        trainer.args.per_device_train_batch_size = 16
+        # After `auto_find_batch_size` is run we should now be at 8
+        self.assertEqual(trainer._train_batch_size, 8)
+
+        # We can then make a new Trainer
+        trainer = Trainer(model, args, train_dataset=train_dataset)
         trainer.train(resume_from_checkpoint=True)
         # We should be back to 8 again
         self.assertEqual(trainer._train_batch_size, 8)

From ed31b3757f2df35b8e9de76f8aa42422b24af5a2 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 5 Dec 2023 09:21:51 -0500
Subject: [PATCH 6/7] Better test

---
 tests/trainer/test_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 6e67787147e3..ebc15c6fde1c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1562,6 +1562,8 @@ def on_step_end(self, args, state, control, **kwargs):

         # We can then make a new Trainer
         trainer = Trainer(model, args, train_dataset=train_dataset)
+        # Check we are at 16 to start
+        self.assertEqual(trainer._train_batch_size, 16)
         trainer.train(resume_from_checkpoint=True)
         # We should be back to 8 again
         self.assertEqual(trainer._train_batch_size, 8)

From 780bf725adb7211b8ce9531f3bce3d4ce9dfe62d Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 5 Dec 2023 09:22:48 -0500
Subject: [PATCH 7/7] More comments

---
 tests/trainer/test_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index ebc15c6fde1c..188e79e2057e 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1565,7 +1565,7 @@ def on_step_end(self, args, state, control, **kwargs):
         # Check we are at 16 to start
         self.assertEqual(trainer._train_batch_size, 16)
         trainer.train(resume_from_checkpoint=True)
-        # We should be back to 8 again
+        # We should be back to 8 again, picking up from the last Trainer run
         self.assertEqual(trainer._train_batch_size, 8)

     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
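
The end-user behavior these patches implement can be summarized with a rough usage sketch (a sketch only, assuming a standard transformers setup; `my_model` and `my_train_dataset` are placeholder names, not objects defined in the patches): the batch size actually found by `auto_find_batch_size` is persisted in the checkpoint's `trainer_state.json`, so a later resume reuses it instead of the configured `per_device_train_batch_size`.

# Rough usage sketch, not part of the patch series; `my_model` and
# `my_train_dataset` are placeholders assumed to exist already.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,  # configured starting batch size
    auto_find_batch_size=True,       # retry with a smaller batch size on CUDA OOM
    save_steps=1,
    max_steps=2,
)

# First run: if batch size 16 OOMs, accelerate's find_executable_batch_size retries
# with a smaller value (e.g. 8), which the Trainer now records in
# TrainerState.train_batch_size and hence in the checkpoint's trainer_state.json.
trainer = Trainer(model=my_model, args=args, train_dataset=my_train_dataset)
trainer.train()

# Second run, resuming from the checkpoint: the Trainer reads train_batch_size back
# from trainer_state.json, so it keeps the batch size that was found (e.g. 8)
# instead of silently going back to the configured 16.
trainer = Trainer(model=my_model, args=args, train_dataset=my_train_dataset)
trainer.train(resume_from_checkpoint=True)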