Skip to content
Merged
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,8 +807,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
if start < param.ds_numel:
elements = min(param.ds_numel - start, partition_size)

dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements)
dest_tensor_full_buffer = partition_buffer.view(-1).narrow(
0,
0,
partition_size)

dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
src_tensor = param.grad.view(-1).narrow(0, start, elements)

# just copy the grad partition to the buffer
Expand Down Expand Up @@ -841,7 +845,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
# elements))

#print("after partition gradients")
param.grad.data = dest_tensor.data
param.grad.data = dest_tensor_full_buffer.data
see_memory_usage("After partitioning gradients", force=False)


Expand Down
21 changes: 15 additions & 6 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,9 @@ def _create_fp16_partitions_with_defragmentation(self):

#create flat buffer in CPU and move to GPU
self.fp16_partitioned_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_partitioned_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i],
1).cuda(
torch.cuda.current_device()))
see_memory_usage(
f"After flattening and moving param group {i} to GPU",
force=False)
Expand All @@ -976,10 +975,12 @@ def _create_fp16_partitions_with_defragmentation(self):
flat_offset,
total_elements)
self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
self.fp16_partitioned_groups_flat[i])
flat_offset += total_elements

# move param to flat buffer for both param offload on/off
self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
self.fp16_partitioned_groups_flat[i])

see_memory_usage(f"After Flattening param group {i}", force=False)

def _create_fp32_partitions(self):
Expand Down Expand Up @@ -1036,6 +1037,14 @@ def setup_zero_stage3_hooks(self):
self.hierarchy = 0
self._register_hooks_recursively(self.module)

#reset step if in inference mode
def _end_of_forward_hook(module, *args):

if not torch._C.is_grad_enabled():
self.param_coordinator.reset_step()

self.module.register_forward_hook(_end_of_forward_hook)

def persistent_parameters(self):
persistent_params = []
total_persistent_parameters = 0
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/test_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")

if zero_stage == 3:
pytest.skip("skip for now")

config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
Expand All @@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
args = args_from_dict(tmpdir, config_dict)

@distributed_test(world_size=2)
def _test_zero_static_scale(args, zero_stage):
hidden_dim = 10
def _test_zero_static_scale(args, zero_stage, hidden_dim):
#making hidden size not divisible by DP for covering this scenario
hidden_dim = hidden_dim
model = SimpleModel(hidden_dim)

model, optim, _, _ = deepspeed.initialize(args=args,
Expand All @@ -393,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage):
model.backward(loss)
model.step()

_test_zero_static_scale(args=args, zero_stage=zero_stage)
#test when hidden_dim is not aligned with world size
_test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9)
#test when hidden_dim is aligned with world size
_test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10)


def test_zero_static_scale_deprecated_format(tmpdir):
Expand Down