Skip to content

Commit

Permalink
refactor to set eval_batch_size earlier if unset, so we can warn if m…
Browse files Browse the repository at this point in the history
…ismatched (axolotl-ai-cloud#662)
  • Loading branch information
winglian authored Oct 3, 2023
1 parent 1ce6bdb commit fbb7a3e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ torch_compile_backend: # Optional[str]
# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size: 2
eval_batch_size:
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
Expand Down
7 changes: 7 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def normalize_config(cfg):
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
Expand Down Expand Up @@ -157,6 +159,11 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.eval_batch_size != cfg.micro_batch_size:
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)

if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated")

Expand Down
4 changes: 1 addition & 3 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
else cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
Expand Down

0 comments on commit fbb7a3e

Please sign in to comment.