
Commit 12057d6

tohtana authored and tjruwase committed
Remove optimizer step on initialization (deepspeedai#5104)
All ZeRO stages (1, 2, and 3) call the optimizer's `step()` during initialization. This increments the optimizer's internal step counter and therefore produces parameter updates that differ from normal PyTorch usage. This PR removes the `step()` call from initialization and instead lazily configures the related internal state (linking *hp_params*) after the first real `step()` call. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
1 parent e4f3544 commit 12057d6
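
For context, here is a minimal standalone sketch (plain PyTorch with made-up values, not DeepSpeed code) of the effect described above: a dummy `step()` with zero gradients leaves the weights untouched but advances Adam's internal step counter, so the first real update is computed with a different bias correction than vanilla PyTorch would apply.

```python
import torch

# Hypothetical illustration only: a dummy step() with zero-valued gradients
# does not move the weights, but it advances Adam's 'step' counter, so the
# bias correction applied on the first real update no longer matches plain PyTorch.
def first_real_update(extra_init_step: bool) -> torch.Tensor:
    p = torch.nn.Parameter(torch.ones(4))
    opt = torch.optim.Adam([p], lr=0.1)

    if extra_init_step:
        p.grad = torch.zeros_like(p)   # mimics the old initialization path
        opt.step()                     # state['step'] becomes 1, weights unchanged

    p.grad = torch.full_like(p, 0.5)   # the first "real" gradient
    opt.step()
    return p.detach().clone()

print(first_real_update(False))  # reference PyTorch result
print(first_real_update(True))   # differs, because bias correction now uses step=2
```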

8 files changed: +141 additions, -96 deletions


deepspeed/runtime/bf16_optimizer.py

Lines changed: 12 additions & 4 deletions
@@ -18,7 +18,7 @@
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                      is_model_parallel_parameter, see_memory_usage, graph_process)
 
-from deepspeed.utils import link_hp_params, fragment_address
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -165,6 +165,7 @@ def _setup_for_real_optimizer(self):
 
         # Need optimizer states initialized before linking lp to optimizer state
         self._link_all_hp_params()
+        self._hp_optimizer_states_linked = False
         self._enable_universal_checkpoint()
         self._param_slice_mappings = self._create_param_mapping()
 
@@ -199,9 +200,15 @@ def _link_all_hp_params(self):
                            param_group_index=i,
                            partition_start=partition_id * partition_size,
                            partition_size=partition_size,
-                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                            dp_group=self.real_dp_process_group[i])
 
+    def _lazy_init_hp_params_optimizer_state(self):
+        if not self._hp_optimizer_states_linked:
+            for i, _ in enumerate(self.optimizer.param_groups):
+                lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i],
+                                                    self.optimizer.state)
+            self._hp_optimizer_states_linked = True
+
     def initialize_optimizer_states(self):
         """Take an optimizer step with zero-valued gradients to allocate internal
         optimizer state.
@@ -215,8 +222,6 @@ def initialize_optimizer_states(self):
             param_partition.grad = grad_partition.to(
                 param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
 
-        self.optimizer.step()
-
         if self.grad_acc_dtype is not torch.float32:
             for param_partition in self.fp32_groups_flat_partition:
                 param_partition.grad = None
@@ -263,6 +268,9 @@ def step(self, closure=None):
 
         self.optimizer.step()
 
+        # We need to link optimizer state after the first step() call
+        self._lazy_init_hp_params_optimizer_state()
+
         self.update_lp_params()
 
         self.clear_hp_grads()
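
The `_hp_optimizer_states_linked` guard above is needed because torch optimizers such as Adam allocate their per-parameter state only on the first real `step()`. A rough standalone check of that behavior (toy tensors and assumed names, not DeepSpeed code):

```python
import torch

# Adam creates exp_avg / exp_avg_sq lazily, so anything that wants to hold
# views into optimizer.state has to wait until after the first step().
flat_partition = torch.nn.Parameter(torch.zeros(8))
opt = torch.optim.Adam([flat_partition], lr=1e-3)

print(len(opt.state))                      # 0 -> no state allocated yet
flat_partition.grad = torch.ones_like(flat_partition)
opt.step()
print(sorted(opt.state[flat_partition]))   # ['exp_avg', 'exp_avg_sq', 'step']
```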

deepspeed/runtime/zero/stage3.py

Lines changed: 0 additions & 4 deletions
@@ -1016,10 +1016,6 @@ def initialize_optimizer_states(self):
             else:
                 self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)
 
-            # Initialize the optimizer states with the flattened fp32 partition.
-            if not is_adagrad:
-                self._optimizer_step(i)
-
             if swappable_param_subgroup:
                 self._partitioned_params_swap_out(i)
 

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 50 additions & 21 deletions
@@ -29,7 +29,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
 from deepspeed.checkpoint import enable_universal_checkpoint
 
 from deepspeed.utils import groups
@@ -89,6 +89,12 @@ def _get_padded_tensor(src_tensor, size):
     return padded_tensor
 
 
+def _pad_tensor_by_size(src_tensor, pad_size, dtype, device):
+    padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device)
+    padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data)
+    return padded_tensor
+
+
 class DeepSpeedZeroOptimizer(ZeROOptimizer):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
@@ -537,6 +543,8 @@ def __init__(self,
         see_memory_usage(f"After initializing ZeRO optimizer", force=True)
 
         self._link_all_hp_params()
+        self._hp_optimizer_states_linked = False
+
         self._enable_universal_checkpoint()
         self._param_slice_mappings = self._create_param_mapping()
 
@@ -579,9 +587,15 @@ def _link_all_hp_params(self):
                            param_group_index=i,
                            partition_start=partition_id * partition_size,
                            partition_size=partition_size,
-                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                            dp_group=self.real_dp_process_group[i])
 
+    def _lazy_init_hp_params_optimizer_state(self):
+        if not self._hp_optimizer_states_linked:
+            for i, _ in enumerate(self.optimizer.param_groups):
+                lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i],
+                                                    self.optimizer.state)
+            self._hp_optimizer_states_linked = True
+
     def is_moe_group(self, group):
         return 'moe' in group and group['moe']
 
@@ -665,8 +679,6 @@ def initialize_optimizer_states(self):
         # which do lazy initialization of the state at the first call to step.
         if isinstance(self.optimizer, torch.optim.Adagrad):
             self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
-        else:
-            self.optimizer.step()
 
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
@@ -1797,7 +1809,9 @@ def _optimizer_step(self, group_no):
         self.optimizer.step()
         self.optimizer.param_groups = original_param_groups
 
-    # @timeit
+        # We need to link optimizer state after the first step() call
+        self._lazy_init_hp_params_optimizer_state()
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -2215,19 +2229,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group
         # Assume non-tensor states are not partitioned and equal across ranks, so return first one
         return all_partition_states[0]
 
-    def _restore_base_optimizer_state(self, base_optimizer_group_states):
+    def _restore_step_from_elastic_checkpoint(self, all_state_dict):
+        assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]
+        assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
+                   for sd in all_state_dict), "State dicts of all partitions must have the same step value"
+        return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
+
+    def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings):
         if type(base_optimizer_group_states) == dict:
             base_optimizer_group_states = base_optimizer_group_states['state']
+
+        saved_keys = base_optimizer_group_states[0].keys()
+
         for i, group in enumerate(self.optimizer.param_groups):
             p = group['params'][0]
-            for key, saved in base_optimizer_group_states[i].items():
-                if torch.is_tensor(self.optimizer.state[p][key]):
-                    dst_tensor = self.optimizer.state[p][key]
-                    src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
-                    self.optimizer.state[p][key].data.copy_(src_tensor.data)
+            padding = 0 if group_paddings is None else group_paddings[i]
+            for key in saved_keys:
+                saved = base_optimizer_group_states[i][key]
+
+                if torch.is_tensor(saved):
+                    if key in self.optimizer.state[p]:
+                        dst_tensor = self.optimizer.state[p][key]
+                        src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
+                        self.optimizer.state[p][key].data.copy_(src_tensor.data)
+                    else:
+                        self.optimizer.state[p][key] = _pad_tensor_by_size(
+                            saved, padding, torch.float32,
+                            torch.device('cpu') if self.cpu_offload else self.device)
                 else:
                     self.optimizer.state[p][key] = saved
 
+        for param_group in self.optimizer.param_groups:
+            param_group['step'] = base_optimizer_state_step
+
     def get_ep_ranks(self, rank=0, group_name=None):
         from deepspeed.utils import groups
         expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)
@@ -2255,15 +2289,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict):
                 partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
             base_optimizer_group_states.append(partition_states)
 
-        self._restore_base_optimizer_state(base_optimizer_group_states)
-
-        # Restore step
-        if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]:
-            assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
-                       for sd in all_state_dict), "State dicts of all partitions must have the same step value"
-            loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP]
-            for param_group in self.optimizer.param_groups:
-                param_group['step'] = loaded_param_groups_step
+        self._restore_base_optimizer_state(base_optimizer_group_states,
+                                           self._restore_step_from_elastic_checkpoint(all_state_dict), None)
 
     def load_state_dict(self,
                         state_dict_list,
@@ -2375,7 +2402,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
                 self._restore_elastic_base_optimizer_state(state_dict_list)
             else:
                 # loading an elastic checkpoint into rigid exec
-                self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])
+                self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE],
+                                                   current_rank_sd[BASE_OPTIMIZER_STATE_STEP],
+                                                   current_rank_sd[GROUP_PADDINGS])
 
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.
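
The new `_pad_tensor_by_size` helper above covers the restore path where a saved state tensor has no existing entry in `optimizer.state` to copy into. A small usage sketch with illustrative values (the helper body is copied from the diff):

```python
import torch

# Pad a saved state tensor with zeros so it matches the padded flat partition
# it will be attached to.
def _pad_tensor_by_size(src_tensor, pad_size, dtype, device):
    padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device)
    padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data)
    return padded_tensor

saved = torch.arange(6, dtype=torch.float32)   # e.g. an exp_avg slice from a checkpoint
print(_pad_tensor_by_size(saved, 2, torch.float32, torch.device('cpu')))
# tensor([0., 1., 2., 3., 4., 5., 0., 0.])
```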

deepspeed/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@
 from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
 from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
 from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
-from .mixed_precision_linkage import link_hp_params
+from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
 from deepspeed.runtime.dataloader import RepeatingLoader
 from .numa import get_numactl_cmd

deepspeed/utils/mixed_precision_linkage.py

Lines changed: 8 additions & 2 deletions
@@ -9,13 +9,19 @@
 
 
 def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
-                   param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group):
+                   param_group_index, partition_start, partition_size, dp_group):
     local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)
 
     for lp_param, lp_start in local_lp_param_and_offset:
         lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
                                                        offload_gradient_dict, use_offload, param_group_index,
-                                                       partition_start, partition_size, partition_optimizer_state)
+                                                       partition_start, partition_size)
+
+
+def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state):
+    for lp in lp_param_list:
+        if lp._hp_mapping is not None:
+            lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition])
 
 
 def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
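
`lazy_init_hp_params_optimizer_state` can hand out per-parameter fragments of the optimizer state because `Tensor.narrow` returns a view that shares storage with the flat partition's state tensors. A quick standalone check (generic tensors, not the actual DeepSpeed objects):

```python
import torch

# narrow() returns a view, so a fragment taken from a flat optimizer-state
# tensor keeps tracking in-place updates made to the flat tensor.
flat_exp_avg = torch.zeros(10)            # stand-in for optimizer.state[flat_hp_partition]['exp_avg']
fragment = flat_exp_avg.narrow(0, 2, 4)   # one parameter's slice: offset 2, 4 elements

flat_exp_avg.add_(1.0)                    # the optimizer updates the flat tensor in place
print(fragment)                           # tensor([1., 1., 1., 1.]) -> the view sees the update

fragment.zero_()                          # writes through the view land in the flat tensor too
print(flat_exp_avg)                       # elements 2..5 are back to 0
```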

deepspeed/utils/tensor_fragment.py

Lines changed: 9 additions & 8 deletions
@@ -21,11 +21,11 @@ class tensor_fragment:
     lp_fragment_address: fragment_address
     hp_fragment: torch.Tensor
     hp_fragment_address: fragment_address
-    optim_fragment: Dict
     gradient_dict: Dict
     offload_gradient_dict: Dict
     use_offload: bool
     param_group_index: int
+    optim_fragment: Dict = None
 
     def update_hp(self):
         self.hp_fragment.data.copy_(self.lp_fragment.data)
@@ -39,6 +39,13 @@ def get_optim_state_fragment(self, key):
         else:
             raise ValueError(f'{key} not found in optimizer state fragment')
 
+    def set_optim_state_fragment(self, flat_hp_partition, optim_fragment):
+        self.optim_fragment = {
+            key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel)
+            for key, value in optim_fragment.items()
+            if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
+        }
+
     def get_hp_fragment_address(self):
         return self.hp_fragment_address
 
@@ -255,7 +262,7 @@ def safe_set_local_fp32_param(param, value):
 
 
 def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
-                            param_group_index, partition_start, partition_size, optimizer_state_dict):
+                            param_group_index, partition_start, partition_size):
     lp_end = lp_param.numel() + lp_start
     hp_start = partition_start
     hp_end = partition_start + partition_size
@@ -268,11 +275,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict
     fragment_numel = fragment_end - fragment_start
     hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
     hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
-    optim_fragment = {
-        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
-        for key, value in optimizer_state_dict.items()
-        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
-    }
 
     lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
     lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)
@@ -281,7 +283,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict
                            lp_fragment_address=lp_frag_address,
                            hp_fragment=hp_fragment_tensor,
                            hp_fragment_address=hp_frag_address,
-                           optim_fragment=optim_fragment,
                            gradient_dict=gradient_dict,
                            offload_gradient_dict=offload_gradient_dict,
                            use_offload=use_offload,
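
`set_optim_state_fragment` above keeps only tensor-valued states shaped like the flat partition; with Adam that means `exp_avg` and `exp_avg_sq` get fragmented while the scalar `step` entry is skipped. A rough sketch of that filter on a toy partition (assumed setup, not DeepSpeed code):

```python
import torch

# Reproduce the filter used by set_optim_state_fragment(): only tensors with
# the partition's shape survive; the scalar 'step' counter does not.
flat_hp_partition = torch.nn.Parameter(torch.zeros(8))
opt = torch.optim.Adam([flat_hp_partition], lr=1e-3)
flat_hp_partition.grad = torch.ones_like(flat_hp_partition)
opt.step()                                 # state exists only after the first step()

state = opt.state[flat_hp_partition]
kept = {key: value for key, value in state.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape}
print(sorted(kept))     # ['exp_avg', 'exp_avg_sq']
print('step' in kept)   # False: the step counter is not fragmented
```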

tests/unit/runtime/zero/test_zero.py

Lines changed: 24 additions & 12 deletions
@@ -1370,6 +1370,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest):
     world_size = 1
 
     def test(self, zero_stage):
+        # We verify three conditions:
+        # 1. global_steps starts at 0
+        # 2. All subgroups have the same step count
+        # 3. The global step count is the same as the step count of the first subgroup
+
         # force all params to be partitioned by forcing threshold=0
         config_dict = {
             "train_micro_batch_size_per_gpu": 2,
@@ -1399,24 +1404,31 @@ def test(self, zero_stage):
                                               model_parameters=model.parameters())
         data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)
 
-        for i, batch in enumerate(data_loader):
+        assert model.global_steps == 0
+
+        for batch in data_loader:
             loss = model(batch[0], batch[1])
             model.backward(loss)
+
+            is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary()
             model.step()
 
-            step_counts = []
-            if zero_stage == 3:
-                for sub_group_id, _ in enumerate(optimizer.fp16_groups):
-                    fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id]
-                    state = optimizer.optimizer.state[fp32_param]
-                    step_counts.append(state["step"])
-                assert all(step == step_counts[0] for step in step_counts)
-            elif zero_stage == 1 or zero_stage == 2:
-                for param_group in optimizer.optimizer.param_groups:
-                    for param in param_group["params"]:
-                        state = optimizer.optimizer.state[param]
+            if is_gradient_accumulation_boundary:
+                step_counts = []
+
+                if zero_stage == 3:
+                    for sub_group_id, _ in enumerate(optimizer.fp16_groups):
+                        fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id]
+                        state = optimizer.optimizer.state[fp32_param]
                         step_counts.append(state["step"])
+                elif zero_stage == 1 or zero_stage == 2:
+                    for param_group in optimizer.optimizer.param_groups:
+                        for param in param_group["params"]:
+                            state = optimizer.optimizer.state[param]
+                            step_counts.append(state["step"])
+
                 assert all(step == step_counts[0] for step in step_counts)
+                assert model.global_steps == step_counts[0]
 
 
 @pytest.mark.parametrize("zero_stage", [1, 2, 3])
0 commit comments
