diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3293968a3dfd14..6bcf4796f8d565 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1565,23 +1565,28 @@ def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
             "logging_steps": "logging_steps",
             "eval_steps": "eval_steps",
             "save_steps": "save_steps",
-            "per_device_train_batch_size": "train_batch_size",
         }
 
-        warnings_list = []
+        has_warning = False
+        warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
         for arg_attr, state_attr in attributes_map.items():
             arg_value = getattr(training_args, arg_attr, None)
             state_value = getattr(trainer_state, state_attr, None)
 
             if arg_value is not None and state_value is not None and arg_value != state_value:
-                warnings_list.append(
-                    f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
-                    f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
-                )
+                warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
+                has_warning = True
+
+        # train bs is special as we need to account for multi-GPU
+        train_bs_args = training_args.per_device_train_batch_size
+        train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
+
+        if train_bs_args != train_bs_state:
+            warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
+            has_warning = True
 
-        if warnings_list:
-            for warning in warnings_list:
-                logger.warning(warning)
+        if has_warning:
+            logger.warning_once(warning_str)
 
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.use_ipex:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 7bbd16737123ab..483022a09e6fab 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2540,16 +2540,14 @@ def test_compare_trainer_and_checkpoint_args_logging(self):
             )
             checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
 
+        self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
+
         self.assertIn(
-            "Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
-            cl.out,
-        )
-        self.assertIn(
-            "Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
+            "per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
             cl.out,
         )
         self.assertIn(
-            "Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
+            "eval_steps: 10 (from args) != 5 (from trainer_state.json)",
             cl.out,
         )