deepspeed/runtime/zero/stage3.py (12 changes: 5 additions & 7 deletions)
@@ -2269,7 +2269,7 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
 
         assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
             "averaged gradients have different number of elements that partition size {} {} {} {}".format(
-                single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id)
+                single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)
 
         self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
 
@@ -2638,14 +2638,12 @@ def get_groups_without_padding(self, groups_with_padding):
     def _set_fp32_optimizer_param_groups(self):
         for sub_group_id, _ in enumerate(self.fp16_groups):
             param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = [
-                self.fp32_partitioned_groups_flat[sub_group_id]
-            ]
+            self.optimizer.param_groups[param_group_id]['params'].append(
+                self.fp32_partitioned_groups_flat[sub_group_id])
 
     def _clear_fp32_optimizer_param_groups(self):
-        for sub_group_id, _ in enumerate(self.fp16_groups):
-            param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = []
+        for param_group in self.optimizer.param_groups:
+            param_group['params'] = []
 
     def _rigid_state_dict(self):
         state_dict = {}
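Side note on the two helpers changed above: `_set_fp32_optimizer_param_groups` now appends each flat fp32 partition to its mapped optimizer group instead of overwriting the group's `params`, and `_clear_fp32_optimizer_param_groups` empties every group rather than only the mapped ones. A minimal stand-alone sketch of that pattern against a plain `torch.optim.Adam` (the flat tensors and the sub-group mapping below are made-up stand-ins, not DeepSpeed's internal objects):

```python
import torch

# Made-up stand-ins for ZeRO-3's flattened fp32 partitions, one per sub-group.
fp32_partitioned_groups_flat = [
    torch.nn.Parameter(torch.zeros(8)),
    torch.nn.Parameter(torch.zeros(8)),
]
sub_group_to_group_id = {0: 0, 1: 0}  # both sub-groups map to optimizer group 0

optimizer = torch.optim.Adam(fp32_partitioned_groups_flat, lr=1e-3)

# _clear_fp32_optimizer_param_groups: empty every param group.
for param_group in optimizer.param_groups:
    param_group['params'] = []

# _set_fp32_optimizer_param_groups: append each flat partition to its group, so
# a group shared by several sub-groups collects all of them instead of keeping
# only the last assignment.
for sub_group_id, flat_partition in enumerate(fp32_partitioned_groups_flat):
    group_id = sub_group_to_group_id[sub_group_id]
    optimizer.param_groups[group_id]['params'].append(flat_partition)

assert len(optimizer.param_groups[0]['params']) == 2
```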
tests/unit/test_checkpointing.py (44 changes: 24 additions & 20 deletions)
@@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
             saved_model.optimizer,
             FP16_DeepSpeedZeroOptimizer_Stage3):
-        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
 
     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
@@ -303,12 +303,13 @@ def _test_checkpoint_fused_optimizer(args,
                           'deepspeed_adam'),
                          (3,
                           False,
-                          'Adam')])
+                          'Adam'),
+                         (3,
+                          True,
+                          'deepspeed_adam')])
 def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -324,8 +325,10 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
             }
         },
         "fp16": {
-            "enabled": True
+            "enabled": True,
+            "initial_scale_power": 8
         },
+        "wall_clock_breakdown": True,
         "zero_optimization": {
             "stage": zero_stage,
             "cpu_offload": use_cpu_offload
@@ -340,9 +343,7 @@ def _test_checkpoint_zero_optimizer(args,
                                         hidden_dim,
                                         load_optimizer_states):
         if zero_stage == 3:
-            global FP16_DeepSpeedZeroOptimizer_Stage3
-            from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
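The `deepspeed.ScatteredParameters(zero_modules=True)` context used previously is replaced by `deepspeed.zero.Init()`, which partitions parameters across ranks as each module is constructed. A rough sketch of how that context pairs with engine creation (the toy model and config below are placeholders, not the tests' `SimpleModel` setup):

```python
import torch
import deepspeed

ds_config = {  # assumed minimal ZeRO stage-3 config in the spirit of these tests
    "train_batch_size": 2,
    "fp16": {"enabled": True, "initial_scale_power": 8},
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

# Parameters created under this context are partitioned at construction time,
# so the full model never has to be materialized on a single rank.
with deepspeed.zero.Init():
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 2))

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)
```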
@@ -371,15 +372,16 @@ def _test_checkpoint_zero_optimizer(args,
                           'deepspeed_adam'),
                          (3,
                           False,
-                          'Adam')])
+                          'Adam'),
+                         (3,
+                          True,
+                          'deepspeed_adam')])
 def test_checkpoint_zero_no_optimizer(tmpdir,
                                       zero_stage,
                                       use_cpu_offload,
                                       adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -413,7 +415,7 @@ def _test_checkpoint_zero_no_optimizer(args,
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -445,12 +447,13 @@ def _test_checkpoint_zero_no_optimizer(args,
                           'deepspeed_adam'),
                          (3,
                           False,
-                          'Adam')])
+                          'Adam'),
+                         (3,
+                          True,
+                          'deepspeed_adam')])
 def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -493,7 +496,7 @@ def _test_checkpoint_lr_scheduler(args,
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -526,14 +529,15 @@ def _test_checkpoint_lr_scheduler(args,
                          (2,
                           True,
                           'deepspeed_adam'),
                          (3,
+                          False,
+                          'Adam'),
+                         (3,
                           True,
-                          'Adam')])
+                          'deepspeed_adam')])
 def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -570,7 +574,7 @@ def _test_checkpoint_no_lr_scheduler(args,
                                          load_optimizer_states,
                                          load_lr_scheduler_states):
         if zero_stage == 3:
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
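Taken together, the new ZeRO-3 parametrizations put stage 3 through the same save/load round trip the other stages already exercised, with optimizer state compared via `fp32_partitioned_groups_flat` as in `compare_model_states`. Continuing from the sketch after the `deepspeed.zero.Init()` hunk above (so `engine` is the engine created there; the path and tag are arbitrary placeholders):

```python
ckpt_dir, ckpt_tag = "/tmp/zero3_ckpt", "step1"

# save, then reload the sharded model and optimizer state on the same engine
engine.save_checkpoint(ckpt_dir, ckpt_tag)
engine.load_checkpoint(ckpt_dir,
                       ckpt_tag,
                       load_optimizer_states=True,
                       load_lr_scheduler_states=False)
```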