Skip to content

Commit

Permalink
Rework tests to compare trainer checkpoint args (#29883)
Browse files Browse the repository at this point in the history
* Start rework

* Fix failing test

* Include max

* Update src/transformers/trainer.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent c32c325 commit ea67410
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
23 changes: 14 additions & 9 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,23 +1565,28 @@ def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
"logging_steps": "logging_steps",
"eval_steps": "eval_steps",
"save_steps": "save_steps",
"per_device_train_batch_size": "train_batch_size",
}

warnings_list = []
has_warning = False
warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
for arg_attr, state_attr in attributes_map.items():
arg_value = getattr(training_args, arg_attr, None)
state_value = getattr(trainer_state, state_attr, None)

if arg_value is not None and state_value is not None and arg_value != state_value:
warnings_list.append(
f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
)
warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
has_warning = True

# train bs is special as we need to account for multi-GPU
train_bs_args = training_args.per_device_train_batch_size
train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)

if train_bs_args != train_bs_state:
warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
has_warning = True

if warnings_list:
for warning in warnings_list:
logger.warning(warning)
if has_warning:
logger.warning_once(warning_str)

def _wrap_model(self, model, training=True, dataloader=None):
if self.args.use_ipex:
Expand Down
10 changes: 4 additions & 6 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,16 +2540,14 @@ def test_compare_trainer_and_checkpoint_args_logging(self):
)
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)

self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)

self.assertIn(
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
"per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
cl.out,
)
self.assertIn(
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
"eval_steps: 10 (from args) != 5 (from trainer_state.json)",
cl.out,
)

Expand Down

0 comments on commit ea67410

Please sign in to comment.