
Commit 26b9ea4: set optimizer's internal states

Author: Masahiro Tanaka (committed)

1 parent 3099308 · commit 26b9ea4

File tree

5 files changed: +33 -5 lines changed

deepspeed/checkpoint/universal_checkpoint.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
 
     hp_keys = []
     for file in os.listdir(folder):
+        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
         pattern = r'(.+).pt'
         match = re.search(pattern, file)
         if match:
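For reference, a minimal sketch (not part of the commit) of what this loop yields for a folder containing the three files named in the new comment; the file list is hard-coded here in place of os.listdir(folder):

import re

# Toy stand-in for the checkpoint folder contents mentioned in the comment above;
# the real code reads these names with os.listdir(folder).
files = ["exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"]

hp_keys = []
for file in files:
    match = re.search(r'(.+).pt', file)
    if match:
        hp_keys.append(match.group(1))

print(hp_keys)  # ['exp_avg', 'exp_avg_sq', 'fp32']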

deepspeed/runtime/bf16_optimizer.py

Lines changed: 8 additions & 2 deletions
@@ -18,7 +18,7 @@
                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                     is_model_parallel_parameter, see_memory_usage, graph_process)
 
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,

@@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         tp_world_size = self.mpu.get_slice_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bf16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)
 
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update
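As a hedged illustration of the state layout this loop rebuilds (not DeepSpeed API; flat_param, optimizer_state, and the 'exp_avg'/'exp_avg_sq' key names are stand-ins assuming an Adam-style optimizer), each collected key ends up as one flat tensor stored under the group's single flat parameter:

import torch

# flat_param plays the role of param_group['params'][0]; optimizer_state plays
# the role of self.optimizer.state. Key names assume an Adam-style optimizer.
flat_param = torch.zeros(10)
optimizer_state = {flat_param: {}}

opt_keys = {'exp_avg', 'exp_avg_sq'}  # collected from each lp's _hp_mapping
for key in opt_keys:
    # map_to_flat_opt_states fills a buffer of this flat shape from the hp fragments
    optimizer_state[flat_param][key] = torch.zeros_like(flat_param)

assert all(v.shape == flat_param.shape for v in optimizer_state[flat_param].values())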

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 8 additions & 2 deletions
@@ -28,7 +28,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 
 from deepspeed.utils import groups

@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)
 
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)

deepspeed/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from .groups import *
 from .nvtx import instrument_w_nvtx
 # TODO: Move tensor fragment and mixed precision to zero utils
-from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
+from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
 from .tensor_fragment import set_full_hp_param
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state

deepspeed/utils/tensor_fragment.py

Lines changed: 15 additions & 0 deletions
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
         return self.get_optim_state_fragment(optim_state_key)
 
 
+def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
+    for key in opt_keys:
+        hp_param = flat_hp_tensor
+        buffer = torch.zeros_like(hp_param)
+
+        for lp in lp_tensors:
+            if lp._hp_mapping is not None:
+                hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
+                hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
+                hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
+                lp._hp_mapping.hp_fragment = hp_fragment
+
+        optim_state[hp_param][key] = buffer
+
+
 def get_full_hp_param(self, optim_state_key=None):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
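A self-contained sketch of the mapping map_to_flat_opt_states performs, with plain tuples standing in for DeepSpeed's _hp_mapping objects; the fragment addresses, values, and the 'exp_avg' key are made up for illustration:

import torch

# Each tuple stands in for one lp tensor's hp fragment: (start, numel, state data).
fragments = [(0, 4, torch.tensor([1., 2., 3., 4.])),
             (4, 2, torch.tensor([5., 6.]))]

flat_hp = torch.zeros(6)       # plays the role of flat_hp_tensor
optim_state = {flat_hp: {}}    # plays the role of the optimizer's state dict

buffer = torch.zeros_like(flat_hp)
for start, numel, data in fragments:
    view = buffer.narrow(0, start, numel)  # view into the flat buffer
    view.copy_(data)                       # the fragment's data now lives in the buffer

optim_state[flat_hp]['exp_avg'] = buffer
print(optim_state[flat_hp]['exp_avg'])  # tensor([1., 2., 3., 4., 5., 6.])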
