Skip to content

Commit

Permalink
Move fix to the right place
Browse files Browse the repository at this point in the history
  • Loading branch information
nemo committed Jan 24, 2025
1 parent 8801b5b commit 9b9cf6d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp
"output_dir": tmp_dir,
"seed": 0,
}

if isinstance(config, AdaLoraConfig):
train_kwargs["learning_rate"] = 1e-2

training_args = TrainingArguments(
torch_compile=not self.fake_compile,
torch_compile_backend=compile_kwargs.get("torch_compile_backend", None),
Expand All @@ -195,7 +199,6 @@ class OptimizerStepCallback(TrainerCallback):
def on_optimizer_step(self, args, state, control, **kwargs):
model.update_and_allocate(state.global_step)
trainer.add_callback(OptimizerStepCallback())
train_kwargs["learning_rate"] = 1e-2

trainer.train()

Expand Down

0 comments on commit 9b9cf6d

Please sign in to comment.