Skip to content

Commit 2943e6a

Browse files
author
Masahiro Tanaka
committed
resume step in param group
1 parent 26b9ea4 commit 2943e6a

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

deepspeed/runtime/engine.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
27852785
if self.load_universal_checkpoint():
27862786
self.optimizer.update_lp_params()
27872787
if load_zero_checkpoint:
2788-
self.update_optimizer_step(step=client_states['iteration'] + 1)
2788+
self.update_optimizer_step(step=client_states['iteration'])
27892789

27902790
return load_path, client_states
27912791

@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
29662966
def update_optimizer_step(self, step):
29672967

29682968
def set_step(d):
2969-
if isinstance(d['step'], torch.Tensor):
2969+
if 'step' in d and isinstance(d['step'], torch.Tensor):
29702970
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
29712971
else:
29722972
d['step'] = step
@@ -2975,8 +2975,7 @@ def set_step(d):
29752975
base_optimizer = optimizer.optimizer
29762976
state = base_optimizer.state
29772977
for group in optimizer.param_groups:
2978-
if 'step' in group:
2979-
set_step(group)
2978+
set_step(group)
29802979
for p in group['params']:
29812980
if p in state and len(state[p]) > 0 and 'step' in state[p]:
29822981
set_step(state[p])

0 commit comments

Comments
 (0)