Skip to content

Commit b4f7b10

Browse files
authored
improve how we setup eval/save strategies and steps (axolotl-ai-cloud#547)
* setup save end eval strategies to be consistent with trainer logic * add comments * better eval handling
1 parent 374a914 commit b4f7b10

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/axolotl/utils/trainer.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -567,21 +567,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
567567
"sample_packing_efficiency"
568568
] = cfg.sample_packing_eff_est
569569

570-
if cfg.val_set_size == 0:
570+
if cfg.eval_steps and cfg.evaluation_strategy:
571+
# assume if the user set both, they know what they're doing
572+
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
573+
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
574+
elif cfg.val_set_size == 0:
575+
# no eval set, so don't eval
571576
training_arguments_kwargs["evaluation_strategy"] = "no"
577+
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
578+
# if explicitly set for epoch, just set, and eval steps don't matter
579+
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
572580
elif cfg.eval_steps:
581+
# steps isn't used w/ epochs
573582
training_arguments_kwargs["evaluation_strategy"] = "steps"
574583
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
575584
else:
576-
# we have an eval set, but no steps defined, use epoch
585+
# we have an eval set, but no steps defined, default to use epoch
577586
training_arguments_kwargs["evaluation_strategy"] = "epoch"
578587

579-
if cfg.save_strategy:
588+
if cfg.save_steps:
589+
# save_steps implies save_strategy of steps
590+
training_arguments_kwargs["save_strategy"] = "steps"
591+
training_arguments_kwargs["save_steps"] = cfg.save_steps
592+
elif cfg.save_strategy:
580593
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
581594
else:
582-
training_arguments_kwargs["save_strategy"] = (
583-
"steps" if cfg.save_steps else "epoch"
584-
)
595+
# default to saving each epoch if not defined
596+
training_arguments_kwargs["save_strategy"] = "epoch"
585597

586598
if cfg.do_bench_eval:
587599
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval

0 commit comments

Comments
 (0)