
Commit 4e47e1d

tjruwase authored and amaurya committed
Precisely track nvme optimizer offload (deepspeedai#6963)
Fix deepspeedai#4998
1 parent 854982e commit 4e47e1d
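
The engine previously reported NVMe offload whenever either parameters or optimizer state lived on NVMe; with this commit, the checkpoint paths that copy the optimizer swap folder fire only when the optimizer itself is swapped. A minimal ZeRO-3 config sketch distinguishing the two offload knobs (values such as nvme_path and the batch-size key are illustrative, not taken from this commit):

# Illustrative ZeRO-3 config: parameters offloaded to NVMe, but optimizer
# state kept on CPU. After this change, such a setup no longer takes the
# optimizer swap-folder copy path during save_checkpoint/load_checkpoint.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
        "offload_optimizer": {"device": "cpu"},
    },
}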

4 files changed, +16 -15 lines changed


deepspeed/runtime/engine.py

Lines changed: 4 additions & 6 deletions
@@ -799,10 +799,8 @@ def zero_load_from_fp32_weights(self):
     def zero_elastic_checkpoint(self):
         return self._config.zero_config.elastic_checkpoint
 
-    def zero_has_nvme_offload(self):
-        if not hasattr(self.optimizer, "swap_optimizer"):
-            return False
-        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
+    def zero_nvme_offload_optimizer(self):
+        return getattr(self.optimizer, "swap_optimizer", False)
 
     def zero_max_live_parameters(self):
         return self._config.zero_config.max_live_parameters
@@ -2876,7 +2874,7 @@ def load_checkpoint(self,
         if not success:
             self.optimizer._restore_from_bit16_weights()
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
@@ -3216,7 +3214,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)
 
-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
            from shutil import copytree, disk_usage
            offload_dir = self.optimizer.optimizer_swapper.swap_folder
            offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
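
The new helper reduces the check to a single getattr with a False default, so it is True only when the optimizer's own state is swapped to NVMe, and it no longer reports offload when only parameters are on NVMe. A minimal sketch of that behavior; the two classes below are hypothetical stand-ins, not DeepSpeed types:

# Hypothetical stand-ins illustrating the renamed check.
class _ParamOnlyNvmeOptimizer:
    # Parameters offloaded to NVMe, optimizer state kept elsewhere.
    swap_optimizer = False
    params_in_nvme_and_cpu = True

class _PlainOptimizer:
    # No swap-related attributes at all (e.g. no NVMe offload configured).
    pass

def zero_nvme_offload_optimizer(optimizer):
    # Mirrors the new engine helper: only the optimizer swap flag matters.
    return getattr(optimizer, "swap_optimizer", False)

assert zero_nvme_offload_optimizer(_ParamOnlyNvmeOptimizer()) is False
assert zero_nvme_offload_optimizer(_PlainOptimizer()) is False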

deepspeed/runtime/swap_tensor/optimizer_utils.py

Lines changed: 5 additions & 0 deletions
@@ -153,6 +153,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume
             'timer_names',
         ]
 
+    def purge_state(self):
+        for swap_info in self.swap_params_info.values():
+            swap_info.tensors = [swap_info.tensors[0]]
+            swap_info.has_state_tensors = False
+
     def swappable_tensor(self, param=None, numel=None):
         assert param is not None or numel is not None, "Either param or numel must be provided"
         if param is not None:
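
purge_state keeps only the first tensor of each swap_params_info entry and clears has_state_tensors, so optimizer state that was initialized against the fresh model is discarded before checkpointed state is loaded. A self-contained sketch of the effect; _MiniSwapper and the tensor labels are illustrative stand-ins, not DeepSpeed objects:

from types import SimpleNamespace

class _MiniSwapper:
    def __init__(self):
        # One per-parameter swap record: parameter tensor plus stale state tensors.
        self.swap_params_info = {
            0: SimpleNamespace(tensors=["fp32_param", "exp_avg", "exp_avg_sq"],
                               has_state_tensors=True)
        }

    def purge_state(self):
        # Same logic as the new OptimizerSwapper.purge_state(): keep only the
        # parameter tensor, drop the stale optimizer state tensors.
        for swap_info in self.swap_params_info.values():
            swap_info.tensors = [swap_info.tensors[0]]
            swap_info.has_state_tensors = False

swapper = _MiniSwapper()
swapper.purge_state()
assert swapper.swap_params_info[0].tensors == ["fp32_param"]
assert swapper.swap_params_info[0].has_state_tensors is False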

deepspeed/runtime/zero/stage3.py

Lines changed: 4 additions & 8 deletions
@@ -2652,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
         self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
         self._clear_fp32_optimizer_param_groups()
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -2773,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             else:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
 
-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()
 
         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
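
Both stage-3 load paths now gate on swap_optimizer alone and delegate the purge to the swapper, so params_in_nvme_and_cpu no longer triggers a purge of state the swapper is not managing. A hedged reading of the shared guard; _purge_stale_swapped_state is a hypothetical name used here for illustration, not a method this commit adds:

def _purge_stale_swapped_state(self):
    # Only meaningful when the optimizer state itself is swapped to NVMe.
    if self.swap_optimizer:
        # The swapped state was initialized from the fresh model, not the
        # checkpoint, so discard it before checkpointed state is applied.
        self.optimizer_swapper.purge_state()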

tests/unit/runtime/zero/test_nvme_checkpointing.py

Lines changed: 3 additions & 1 deletion
@@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
     world_size = 1
 
     @pytest.mark.parametrize('param_offload_device, optim_offload_device',
-                             [(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
+                             [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
                               (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
                               (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
     def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
         zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
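
The parametrization now covers parameter-only, optimizer-only, and mixed NVMe placements instead of a cpu/cpu baseline. A small sketch, assuming swap_optimizer tracks an NVMe optimizer offload device as in ZeRO-3, of which cases should exercise the renamed zero_nvme_offload_optimizer() path:

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

# The five (param, optimizer) offload combinations from the updated test.
cases = [
    (OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
    (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
    (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
    (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
    (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme),
]

for param_dev, optim_dev in cases:
    # Only optimizer-on-NVMe cases should copy the optimizer swap folder at checkpoint time.
    optimizer_on_nvme = optim_dev == OffloadDeviceEnum.nvme
    print(f"param={param_dev.value:>4}  optim={optim_dev.value:>4}  "
          f"swap folder copied: {optimizer_on_nvme}")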
