diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index a35582973a..f8afb883f0 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -97,12 +97,11 @@ def get_lr(self) -> List[float]: ) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] - return [ - (1 + math.cos( + return [(1 + math.cos( math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs) - )) / (1 + math.cos( - math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) - )) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups + )) / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / ( + self.max_epochs - self.warmup_epochs + ))) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] def _get_closed_form_lr(self) -> List[float]: