@@ -29,7 +29,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
 from deepspeed.checkpoint import enable_universal_checkpoint
 
 from deepspeed.utils import groups
@@ -89,6 +89,12 @@ def _get_padded_tensor(src_tensor, size):
     return padded_tensor
 
 
+def _pad_tensor_by_size(src_tensor, pad_size, dtype, device):
+    padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device)
+    padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data)
+    return padded_tensor
+
+
 class DeepSpeedZeroOptimizer(ZeROOptimizer):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
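# Note (not part of the diff): a minimal sketch contrasting the two padding helpers.
# _get_padded_tensor pads up to a target *total* size, while the new _pad_tensor_by_size
# appends pad_size extra elements; the sizes below are made up for illustration.
import torch

src = torch.tensor([1., 2., 3.])

# _get_padded_tensor(src, 5): pad to a total of 5 elements.
total = torch.zeros(5, dtype=src.dtype, device=src.device)
total[:src.numel()].copy_(src)                     # -> [1., 2., 3., 0., 0.]

# _pad_tensor_by_size(src, 2, torch.float32, src.device): append 2 extra elements.
extra = torch.zeros(src.numel() + 2, dtype=torch.float32, device=src.device)
extra[:src.numel()].copy_(src)                     # -> [1., 2., 3., 0., 0.]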
@@ -537,6 +543,8 @@ def __init__(self,
         see_memory_usage(f"After initializing ZeRO optimizer", force=True)
 
         self._link_all_hp_params()
+        self._hp_optimizer_states_linked = False
+
         self._enable_universal_checkpoint()
         self._param_slice_mappings = self._create_param_mapping()
@@ -579,9 +587,15 @@ def _link_all_hp_params(self):
                            param_group_index=i,
                            partition_start=partition_id * partition_size,
                            partition_size=partition_size,
-                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                            dp_group=self.real_dp_process_group[i])
 
+    def _lazy_init_hp_params_optimizer_state(self):
+        if not self._hp_optimizer_states_linked:
+            for i, _ in enumerate(self.optimizer.param_groups):
+                lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i],
+                                                    self.optimizer.state)
+            self._hp_optimizer_states_linked = True
+
     def is_moe_group(self, group):
         return 'moe' in group and group['moe']
 
@@ -665,8 +679,6 @@ def initialize_optimizer_states(self):
         # which do lazy initialization of the state at the first call to step.
         if isinstance(self.optimizer, torch.optim.Adagrad):
             self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
-        else:
-            self.optimizer.step()
 
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
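# Note (not part of the diff): a minimal sketch of why the dummy optimizer.step() can be
# dropped and state linking deferred -- torch.optim.Adam (unlike Adagrad) only materializes
# its per-parameter state (exp_avg, exp_avg_sq) lazily on the first real step().
import torch

p = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([p])
print(opt.state[p])        # {} -- no optimizer state exists yet

p.grad = torch.ones_like(p)
opt.step()
print(list(opt.state[p]))  # ['step', 'exp_avg', 'exp_avg_sq'] -- created on first step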
@@ -1797,7 +1809,9 @@ def _optimizer_step(self, group_no):
         self.optimizer.step()
         self.optimizer.param_groups = original_param_groups
 
-    # @timeit
+        # We need to link optimizer state after the first step() call
+        self._lazy_init_hp_params_optimizer_state()
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -2215,19 +2229,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group
             # Assume non-tensor states are not partitioned and equal across ranks, so return first one
             return all_partition_states[0]
 
-    def _restore_base_optimizer_state(self, base_optimizer_group_states):
+    def _restore_step_from_elastic_checkpoint(self, all_state_dict):
+        assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]
+        assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
+                   for sd in all_state_dict), "State dicts of all partitions must have the same step value"
+        return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
+
+    def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings):
         if type(base_optimizer_group_states) == dict:
             base_optimizer_group_states = base_optimizer_group_states['state']
+
+        saved_keys = base_optimizer_group_states[0].keys()
+
         for i, group in enumerate(self.optimizer.param_groups):
             p = group['params'][0]
-            for key, saved in base_optimizer_group_states[i].items():
-                if torch.is_tensor(self.optimizer.state[p][key]):
-                    dst_tensor = self.optimizer.state[p][key]
-                    src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
-                    self.optimizer.state[p][key].data.copy_(src_tensor.data)
+            padding = 0 if group_paddings is None else group_paddings[i]
+            for key in saved_keys:
+                saved = base_optimizer_group_states[i][key]
+
+                if torch.is_tensor(saved):
+                    if key in self.optimizer.state[p]:
+                        dst_tensor = self.optimizer.state[p][key]
+                        src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
+                        self.optimizer.state[p][key].data.copy_(src_tensor.data)
+                    else:
+                        self.optimizer.state[p][key] = _pad_tensor_by_size(
+                            saved, padding, torch.float32,
+                            torch.device('cpu') if self.cpu_offload else self.device)
                 else:
                     self.optimizer.state[p][key] = saved
 
+        for param_group in self.optimizer.param_groups:
+            param_group['step'] = base_optimizer_state_step
+
     def get_ep_ranks(self, rank=0, group_name=None):
         from deepspeed.utils import groups
         expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)
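# Note (not part of the diff): a rough sketch of the new fallback branch in
# _restore_base_optimizer_state. When a state key is absent from the live optimizer
# (state is now created lazily, so nothing exists before the first step()), the saved
# tensor is installed directly, zero-padded by this group's padding. Sizes are made up.
import torch

saved = torch.arange(10, dtype=torch.float32)   # checkpointed state slice for one group
padding = 6                                     # this group's padding (group_paddings[i])

# Equivalent of _pad_tensor_by_size(saved, padding, torch.float32, device)
restored = torch.zeros(saved.numel() + padding, dtype=torch.float32)
restored[:saved.numel()].copy_(saved)
assert restored.numel() == 16                   # padded out to the flat partition size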
@@ -2255,15 +2289,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict):
                 partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
             base_optimizer_group_states.append(partition_states)
 
-        self._restore_base_optimizer_state(base_optimizer_group_states)
-
-        # Restore step
-        if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]:
-            assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
-                       for sd in all_state_dict), "State dicts of all partitions must have the same step value"
-            loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
-            for param_group in self.optimizer.param_groups:
-                param_group['step'] = loaded_param_groups_step
+        self._restore_base_optimizer_state(base_optimizer_group_states,
+                                           self._restore_step_from_elastic_checkpoint(all_state_dict), None)
 
     def load_state_dict(self,
                         state_dict_list,
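# Note (not part of the diff): a toy illustration of the step-consistency check that
# _restore_step_from_elastic_checkpoint performs, using a stand-in key instead of the
# real BASE_OPTIMIZER_STATE_STEP constant from deepspeed.checkpoint.constants.
STEP_KEY = 'base_optimizer_state_step'
all_state_dict = [{STEP_KEY: 1000}, {STEP_KEY: 1000}]  # made-up per-partition state dicts

assert STEP_KEY in all_state_dict[0]
assert all(sd[STEP_KEY] == all_state_dict[0][STEP_KEY] for sd in all_state_dict), \
    "State dicts of all partitions must have the same step value"
step = all_state_dict[0][STEP_KEY]  # 1000; later written into every param_group['step']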
@@ -2375,7 +2402,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
                     self._restore_elastic_base_optimizer_state(state_dict_list)
                 else:
                     # loading an elastic checkpoint into rigid exec
-                    self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])
+                    self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE],
+                                                       current_rank_sd[BASE_OPTIMIZER_STATE_STEP],
+                                                       current_rank_sd[GROUP_PADDINGS])
 
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.