Skip to content

Commit

Permalink
Add a step-initialization check to the unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
jomayeri committed Oct 2, 2024
1 parent 15c74cf commit 92bb0ff
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/lr_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def __init__(self,
def get_lr(self):
if self.last_batch_iteration < 0:
logger.warning("Attempting to get learning rate from scheduler before it has started")
return [0.0]
return self.min_lrs
gamma = self._get_gamma()
return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]

Expand Down
12 changes: 11 additions & 1 deletion tests/unit/runtime/test_lr_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def _verify_staircase_increase(values, step_size):
(WARMUP_DECAY_LR, {
WARMUP_NUM_STEPS: 10,
TOTAL_NUM_STEPS: 20
}), (ONE_CYCLE, {
}),
(WARMUP_COSINE_LR, {
WARMUP_NUM_STEPS: 10,
TOTAL_NUM_STEPS: 20
}),
(ONE_CYCLE, {
CYCLE_MIN_LR: 0,
CYCLE_MAX_LR: 0.1
}), (LR_RANGE_TEST, {})])
Expand Down Expand Up @@ -71,6 +76,11 @@ def test(self, scheduler_type, params):
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)

true_lrs = lr_scheduler.get_lr()
for group, true_lr in zip(model.optimizer.param_groups, true_lrs):
assert group['lr'] == true_lr, f"True lr {true_lr}, optimizer lr {group['lr']}"

for n, batch in enumerate(data_loader):
# get lr before training starts
lr_scheduler.get_lr()
Expand Down

0 comments on commit 92bb0ff

Please sign in to comment.