adding lr init in schedulers
jomayeri committed Oct 1, 2024
1 parent 8cded57 commit 15c74cf
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions deepspeed/runtime/lr_schedules.py
@@ -675,6 +675,11 @@ def __init__(self,
         self.warmup_type = warmup_type
         self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
         self.last_batch_iteration = last_batch_iteration
+        # Initialize lr in optimizer
+        if last_batch_iteration == -1:
+            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+                param_group['lr'] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
     def get_lr(self):
         if self.last_batch_iteration < 0:
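
For context, a minimal self-contained sketch of the pattern this hunk introduces. TinyWarmupLR and its linear warmup rule are illustrative assumptions, not DeepSpeed's actual WarmupLR: the point is that without the guarded block the optimizer keeps its configured base lr until the first scheduler step, whereas a fresh scheduler (last_batch_iteration == -1) now writes the schedule's starting lr into every param group at construction.

    # Sketch only: TinyWarmupLR and its linear warmup are hypothetical stand-ins.
    import torch

    class TinyWarmupLR:
        def __init__(self, optimizer, warmup_min_lr=0.0, warmup_max_lr=0.001,
                     warmup_num_steps=1000, last_batch_iteration=-1):
            self.optimizer = optimizer
            self.min_lr = warmup_min_lr
            self.max_lr = warmup_max_lr
            self.warmup_num_steps = warmup_num_steps
            self.last_batch_iteration = last_batch_iteration
            # The pattern from the commit: on a fresh start (-1), push the
            # schedule's lr into the optimizer instead of waiting for step().
            if last_batch_iteration == -1:
                for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                    param_group['lr'] = lr
                self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

        def get_lr(self):
            # Illustrative linear warmup; the real scheduler also supports a log type.
            step = max(0, self.last_batch_iteration)
            gamma = min(1.0, step / float(self.warmup_num_steps))
            return [self.min_lr + gamma * (self.max_lr - self.min_lr)
                    for _ in self.optimizer.param_groups]

        def step(self):
            self.last_batch_iteration += 1
            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                param_group['lr'] = lr
            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)   # configured base lr
    sched = TinyWarmupLR(opt, warmup_max_lr=0.01, warmup_num_steps=100)
    print(opt.param_groups[0]['lr'])  # 0.0: the schedule's step-0 value, no longer 0.1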
@@ -818,6 +823,12 @@ def __init__(self,
             logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                 total_num_steps, warmup_num_steps))
         self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
 
+        # Initialize lrs in optimizer groups
+        if last_batch_iteration == -1:
+            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+                param_group['lr'] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
     def get_lr_ratio(self):
         if self.last_batch_iteration < 0:
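
The same guard now runs in this second scheduler's __init__ as well. A usage sketch of the observable effect follows, assuming the WarmupLR constructor exported by this file; the warmup_min_lr/warmup_max_lr/warmup_num_steps keyword names are an assumption here, not verified against this revision.

    import torch
    from deepspeed.runtime.lr_schedules import WarmupLR  # the file patched above

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)  # base lr set on the optimizer
    sched = WarmupLR(opt, warmup_min_lr=0.0, warmup_max_lr=0.01, warmup_num_steps=100)

    # Before this commit, param_groups kept the base lr (0.1) until the first
    # sched.step(). After it, __init__ has already written get_lr()'s value into
    # each group, so training step 0 runs at the schedule's starting lr.
    print(opt.param_groups[0]['lr'])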
