diff --git a/main_dino.py b/main_dino.py index dff726313..46b8cc1c5 100644 --- a/main_dino.py +++ b/main_dino.py @@ -227,9 +227,11 @@ def train_dino(args): fp16_scaler = torch.cuda.amp.GradScaler() # ============ init schedulers ... ============ + # linear scaling rule + lr_scale = (args.batch_size_per_gpu * utils.get_world_size()) / 256. lr_schedule = utils.cosine_scheduler( - args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule - args.min_lr, + args.lr * lr_scale, + args.min_lr * lr_scale, args.epochs, len(data_loader), warmup_epochs=args.warmup_epochs, )