diff --git a/examples/language_model/transformer-xl/train.py b/examples/language_model/transformer-xl/train.py index 78f371a1614fc..79d69fcbadaa8 100644 --- a/examples/language_model/transformer-xl/train.py +++ b/examples/language_model/transformer-xl/train.py @@ -287,8 +287,8 @@ def do_train(args): optimizer.set_lr(curr_lr) elif args.scheduler == 'noam': scheduler.step() - if step_idx >= args.max_step: - break + if step_idx >= args.max_step: + return if args.save_model and rank == 0: model_dir = os.path.join(args.save_model, "step_final")