Skip to content

Commit 3ed6708

Browse files
tjruwase authored and traincheck-team committed
Precisely track nvme optimizer offload (deepspeedai#6963)
Fix deepspeedai#4998
1 parent 5a597bc commit 3ed6708

File tree

4 files changed

+16
-15
lines changed

4 files changed

+16
-15
lines changed

deepspeed/runtime/engine.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -799,10 +799,8 @@ def zero_load_from_fp32_weights(self):
799799
def zero_elastic_checkpoint(self):
800800
return self._config.zero_config.elastic_checkpoint
801801

802-
def zero_has_nvme_offload(self):
803-
if not hasattr(self.optimizer, "swap_optimizer"):
804-
return False
805-
return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
802+
def zero_nvme_offload_optimizer(self):
803+
return getattr(self.optimizer, "swap_optimizer", False)
806804

807805
def zero_max_live_parameters(self):
808806
return self._config.zero_config.max_live_parameters
@@ -2865,7 +2863,7 @@ def load_checkpoint(self,
28652863
if not success:
28662864
self.optimizer._restore_from_bit16_weights()
28672865

2868-
if self.zero_has_nvme_offload():
2866+
if self.zero_nvme_offload_optimizer():
28692867
from shutil import copytree, disk_usage
28702868
offload_dir = self.optimizer.optimizer_swapper.swap_folder
28712869
offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
@@ -3205,7 +3203,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
32053203
self._create_zero_checkpoint_files(save_dir, tag)
32063204
self._save_zero_checkpoint(save_dir, tag)
32073205

3208-
if self.zero_has_nvme_offload():
3206+
if self.zero_nvme_offload_optimizer():
32093207
from shutil import copytree, disk_usage
32103208
offload_dir = self.optimizer.optimizer_swapper.swap_folder
32113209
offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")

deepspeed/runtime/swap_tensor/optimizer_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume
153153
'timer_names',
154154
]
155155

156+
def purge_state(self):
157+
for swap_info in self.swap_params_info.values():
158+
swap_info.tensors = [swap_info.tensors[0]]
159+
swap_info.has_state_tensors = False
160+
156161
def swappable_tensor(self, param=None, numel=None):
157162
assert param is not None or numel is not None, "Either param or numel must be provided"
158163
if param is not None:

deepspeed/runtime/zero/stage3.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
26522652
self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
26532653
self._clear_fp32_optimizer_param_groups()
26542654

2655-
if self.swap_optimizer or self.params_in_nvme_and_cpu:
2655+
if self.swap_optimizer:
26562656
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
2657-
for swap_info in self.optimizer_swapper.swap_params_info.values():
2658-
swap_info.tensors = [swap_info.tensors[0]]
2659-
swap_info.has_state_tensors = False
2657+
self.optimizer_swapper.purge_state()
26602658

26612659
if self.swap_optimizer:
26622660
# Touch all parameters to synchronize all buffers
@@ -2773,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
27732771
else:
27742772
optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
27752773

2776-
if self.swap_optimizer or self.params_in_nvme_and_cpu:
2774+
if self.swap_optimizer:
27772775
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
2778-
for swap_info in self.optimizer_swapper.swap_params_info.values():
2779-
swap_info.tensors = [swap_info.tensors[0]]
2780-
swap_info.has_state_tensors = False
2776+
self.optimizer_swapper.purge_state()
27812777

27822778
if self.swap_optimizer:
27832779
# Touch all parameters to synchronize all buffers

tests/unit/runtime/zero/test_nvme_checkpointing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
2222
world_size = 1
2323

2424
@pytest.mark.parametrize('param_offload_device, optim_offload_device',
25-
[(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
25+
[(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
2626
(OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
27+
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
28+
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
2729
(OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
2830
def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
2931
zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")

0 commit comments

Comments
 (0)