Skip to content

Commit

Permalink
bugfix for offload_states
Browse files Browse the repository at this point in the history
bugfix for offload_states

Signed-off-by: Wei Wu <wuwei211x@gmail.com>
  • Loading branch information
U-rara committed Feb 18, 2025
1 parent 14b3cce commit 836b55b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2949,8 +2949,12 @@ def reload_states(self, non_blocking: bool = False):
self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
self._set_fp16_partitioned_groups_flat()

- for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(
- [p.ds_tensor for p in self.module.parameters()]):
+ parameter_partitions: List[Tensor] = []
+ for sub_group in self.fp16_groups:
+ for param in sub_group:
+ parameter_partitions.append(param.ds_tensor)
+
+ for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(parameter_partitions):
tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel)
self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)

Expand Down
9 changes: 6 additions & 3 deletions tests/unit/runtime/zero/test_offload_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def compare_device(state) -> bool:
assert compare_device(state), f"State {state} is not on device {device}"


- def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
+ def run_model(model, param_groups, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):
# Currently we only support OffloadDeviceEnum.cpu
offload_device = OffloadDeviceEnum.cpu

- model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+ model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=10,
hidden_dim=hidden_dim,
Expand Down Expand Up @@ -124,5 +124,8 @@ def test_offload_states(self, included_state, pin_memory, non_blocking):
with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim, nlayers=4)

+ params = list(model.parameters())
+ param_groups = [{"params": params[::2], "weight_decay": 0.0}, {"params": params[1::2], "weight_decay": 0.1}]
+
include = None if included_state is None else [included_state]
- run_model(model, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking)
+ run_model(model, param_groups, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking)

0 comments on commit 836b55b

Please sign in to comment.