
Commit 98bd0e4

tohtana authored and dbyoung18 committed
Fix loading a universal checkpoint (deepspeedai#5263)
This PR fixes the following two points regarding checkpoint loading.

- Load optimizer states

  With [this PR](deepspeedai#5104), we removed the optimizer's `step()` on initialization. This made DeepSpeed's parameter update match PyTorch's normal behavior. However, we no longer have keys in the optimizer states when we load a checkpoint. For legacy/elastic checkpoints, that PR changed the checkpoint loaders to create keys and buffers on loading, but the loader for universal checkpoints still relies on keys in the optimizer states. As a result, loading a universal checkpoint fails. This PR fixes the loader to find the optimizer state keys from a given checkpoint.

- Resume step count (deepspeedai@2943e6a)

  The checkpoint loader for a universal checkpoint resumes the step count for the optimizer only when the param group already has `step`. But some optimizers create the key `step` in a param group only at the first call of `step()` (e.g. Apex [Fused Adam](https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/apex/optimizers/fused_adam.py#L154)). In this case, the step count is not restored. This PR changes this behavior to always set the step count in a param group.

  This PR also stops incrementing the step count when loading. In my small example I didn't see why we need to increment the step count, but we may need a discussion to consider various cases.
1 parent d57e635 commit 98bd0e4

File tree

6 files changed: 58 additions & 21 deletions


deepspeed/checkpoint/universal_checkpoint.py

Lines changed: 22 additions & 11 deletions
@@ -4,22 +4,26 @@
 # DeepSpeed Team
 
 import os
+import re
 import torch
 import types
 from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)
 
 
 def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     hp_mapping = self._hp_mapping
-    optim_state_keys = hp_mapping.get_optim_state_keys()
-    hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
-    #print(f'{hp_keys=}')
-    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
-    for file in checkpoint_files.values():
-        assert os.path.isfile(file), f'{file} is not a valid file'
+    hp_mapping.optim_fragment = {}
+
+    hp_keys = []
+    for file in os.listdir(folder):
+        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
+        pattern = r'(.+).pt'
+        match = re.search(pattern, file)
+        if match:
+            hp_keys.append(match.group(1))
 
     for key in hp_keys:
-        ckpt_file = checkpoint_files[key]
+        ckpt_file = os.path.join(folder, f"{key}.pt")
         ckpt_dict = torch.load(ckpt_file)
         full_hp_param = ckpt_dict[PARAM]
 
@@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
 
         assert full_param_numel == tp_world_size * tp_slice_numel, \
             f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
-        dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)
 
         # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
         # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
 
         lp_frag_address = hp_mapping.lp_fragment_address
         tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
-        assert dst_tensor.numel() == lp_frag_address.numel, \
-            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
 
         # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
         # print(f"{key} SHAPE: {dst_tensor.shape=}")
         # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
-        dst_tensor.data.copy_(tp_hp_fragment.data)
+
+        if key == FP32_WEIGHT_KEY:
+            dst_tensor = hp_mapping.get_hp_fragment()
+            assert dst_tensor.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
+            dst_tensor.data.copy_(tp_hp_fragment.data)
+        else:
+            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'
+
+            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()
 
 
 def enable_universal_checkpoint(param_list):
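For illustration, here is a minimal standalone sketch of the key-discovery idea in the hunk above: the optimizer-state keys are recovered from the `*.pt` file names in a per-parameter checkpoint folder rather than from the optimizer state, which is empty before the first `step()`. The helper name and the example path are hypothetical, not DeepSpeed APIs.

```python
import os
import re

def discover_hp_keys(folder: str) -> list:
    """Hypothetical helper: recover state keys ("fp32", "exp_avg", "exp_avg_sq", ...)
    from the checkpoint file names themselves rather than from the optimizer state."""
    hp_keys = []
    for file_name in os.listdir(folder):
        match = re.search(r'(.+)\.pt$', file_name)
        if match:
            hp_keys.append(match.group(1))
    return hp_keys

# Usage (path is illustrative only):
# print(discover_hp_keys("global_step100/zero/module.layers.0.weight"))
```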

deepspeed/runtime/bf16_optimizer.py

Lines changed: 8 additions & 2 deletions
@@ -18,7 +18,7 @@
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                      is_model_parallel_parameter, see_memory_usage, graph_process)
 
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         tp_world_size = self.mpu.get_slice_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bf16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)
 
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update

deepspeed/runtime/engine.py

Lines changed: 4 additions & 5 deletions
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
 
@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
     def update_optimizer_step(self, step):
 
         def set_step(d):
-            if isinstance(d['step'], torch.Tensor):
+            if 'step' in d and isinstance(d['step'], torch.Tensor):
                 d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
             else:
                 d['step'] = step
@@ -2975,10 +2975,9 @@ def set_step(d):
         base_optimizer = optimizer.optimizer
         state = base_optimizer.state
         for group in optimizer.param_groups:
-            if 'step' in group:
-                set_step(group)
+            set_step(group)
             for p in group['params']:
-                if p in state and len(state[p]) > 0 and 'step' in state[p]:
+                if p in state and len(state[p]) > 0:
                     set_step(state[p])
 
     def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
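To make the intent of the changed guards concrete, below is a hedged standalone sketch (the `restore_step` helper is an invented name, not engine.py itself): `step` is now written into every param group, so optimizers such as Apex FusedAdam that only create the key on their first `step()` call still get the resumed count, while per-param state is only touched when it already exists.

```python
import torch

def restore_step(optimizer, step):
    # Mirrors the restore logic after this change (sketch only).
    def set_step(d):
        # 'step' may be absent (e.g. before the optimizer's first step()) or stored as a tensor.
        if 'step' in d and isinstance(d['step'], torch.Tensor):
            d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
        else:
            d['step'] = step

    state = optimizer.state
    for group in optimizer.param_groups:
        set_step(group)  # always set, even if the group has no 'step' yet
        for p in group['params']:
            if p in state and len(state[p]) > 0:
                set_step(state[p])

# Example with a stock optimizer that has not stepped yet:
p = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([p])
restore_step(opt, step=100)
print(opt.param_groups[0]['step'])  # 100
```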

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 8 additions & 2 deletions
@@ -28,7 +28,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 
 from deepspeed.utils import groups
@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)
 
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)

deepspeed/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from .groups import *
 from .nvtx import instrument_w_nvtx
 # TODO: Move tensor fragment and mixed precision to zero utils
-from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
+from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
 from .tensor_fragment import set_full_hp_param
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state

deepspeed/utils/tensor_fragment.py

Lines changed: 15 additions & 0 deletions
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
         return self.get_optim_state_fragment(optim_state_key)
 
 
+def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
+    for key in opt_keys:
+        hp_param = flat_hp_tensor
+        buffer = torch.zeros_like(hp_param)
+
+        for lp in lp_tensors:
+            if lp._hp_mapping is not None:
+                hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
+                hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
+                hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
+                lp._hp_mapping.hp_fragment = hp_fragment
+
+        optim_state[hp_param][key] = buffer
+
+
 def get_full_hp_param(self, optim_state_key=None):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
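The new `map_to_flat_opt_states` copies each linked low-precision parameter's state fragment into a flat buffer at its recorded offset and registers that buffer as the optimizer state of the flat high-precision tensor. A toy sketch of the same narrow/copy pattern, with made-up sizes and without DeepSpeed's mapping objects:

```python
import torch

# Flat fp32 partition that owns the optimizer state.
flat_hp = torch.zeros(10)

# Per-parameter "exp_avg" fragments with their start offsets into the flat tensor.
fragments = [(0, torch.ones(4)), (4, torch.full((6,), 2.0))]

buffer = torch.zeros_like(flat_hp)
for start, frag in fragments:
    # Copy each fragment into its slice of the flat buffer, as narrow()/copy_() does above.
    buffer.narrow(0, start, frag.numel()).copy_(frag)

# torch optimizers key their state dict by the parameter tensor itself.
optim_state = {flat_hp: {}}
optim_state[flat_hp]['exp_avg'] = buffer
print(optim_state[flat_hp]['exp_avg'])  # tensor([1., 1., 1., 1., 2., 2., 2., 2., 2., 2.])
```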
