
Commit cd20a3b

wenbinc-Bin, tjruwase, loadams, and hwchen2017 authored
Fix potential memory issues when using DeepSpeed ZeRO-3 (#6726)
I hit an OOM problem when doing DPO training with ZeRO-3. DPO needs to call the module twice in one training step, and the second call is under no_grad(). The problem is caused by two bugs:

1. "__n_available_params", which helps control the number of fetched parameters, becomes negative after release_and_reset_all().
2. module.ds_grads_remaining becomes negative in backward() if the module is called more than once in one training step.

This commit contains two patches that fix these issues.

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
1 parent f515104 commit cd20a3b
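
To make the failure mode concrete, here is a minimal sketch of the call pattern described above (an assumed shape of a DPO-style step, not the actual trainer code): the same ZeRO-3 engine is called twice per training step, with the second forward under torch.no_grad().

```python
import torch

# Minimal sketch of the double-forward pattern that exposed the bugs
# (assumed shape of a DPO-style step; model_engine is a deepspeed.initialize()
# engine running ZeRO stage 3, and the batch layout is illustrative only).
def training_step(model_engine, batch):
    loss_policy = model_engine(batch[0], batch[1])       # first forward, with grads
    with torch.no_grad():
        loss_ref = model_engine(batch[0], batch[1])      # second forward, no grads
    loss = loss_policy + loss_ref
    model_engine.backward(loss)   # previously left ds_grads_remaining negative
    model_engine.step()
    return loss
```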

File tree

3 files changed: +56 / -3 lines

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -392,7 +392,8 @@ def _run_before_forward_function(input):
                                                      _run_after_backward_hook, inputs)
 
         def _post_backward_module_hook(module, inputs):
-            module.ds_grads_remaining = 0
+            if not hasattr(module, "ds_grads_remaining"):
+                module.ds_grads_remaining = 0
 
             if not hasattr(module, "post_bwd_fn"):
 
```
deepspeed/runtime/zero/partitioned_param_coordinator.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -252,7 +252,6 @@ def reset_step(self) -> None:
         self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
         self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
         self.__step_id = 0
-        self.__n_available_params = 0
         self.__profiler.reset_events()
 
     def _dump_params(self, tag, sub_module, params, step_id=None):
@@ -430,7 +429,7 @@ def release_and_reset_all(self, module: Module) -> None:
             # there's a hook execution issue
             param.ds_active_sub_modules.clear()
             self.__release_param(param)
-
+        self.__n_available_params = 0
         for param in iter_params(module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")
```
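
For intuition only, a toy sequence (made-up numbers, not DeepSpeed's real coordinator) of how zeroing the counter in reset_step() could let a later release drive it negative, which is why the reset now happens inside release_and_reset_all(), after everything has actually been released:

```python
# Toy bookkeeping, not DeepSpeed's PartitionedParameterCoordinator.
n_available_params = 0

n_available_params += 100    # params fetched for the grad-enabled forward
n_available_params += 100    # params fetched again for the no_grad forward

# Old behaviour: reset_step() zeroed the counter while params were still fetched.
n_available_params = 0

# Those params still get released later, so the counter goes negative and any
# prefetch budget derived from it is skewed (a plausible path to the OOM).
n_available_params -= 200
assert n_available_params == -200

# New behaviour: release everything first, then zero the counter inside
# release_and_reset_all(), so each step starts from a consistent 0.
n_available_params = 0
```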
New test file (path not shown in this excerpt)

Lines changed: 53 additions & 0 deletions

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel, random_dataloader


class TestZ3MultipleModelCall(DistributedTest):
    world_size = 1

    def test_z3_multiple_model_call(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": 3
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        if preferred_dtype() is torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif preferred_dtype() is torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        hidden_dim, nlayers = 2048, 3
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
        model_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                     model=model,
                                                     model_parameters=model.parameters())
        data_loader = iter(
            random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device))

        for n, batch in enumerate(data_loader):
            loss1 = model_engine(batch[0], batch[1])
            with torch.no_grad():
                loss2 = model_engine(batch[0], batch[1])
            loss = loss1 + loss2
            model_engine.backward(loss)
            for name, submodule in model_engine.module.linears._modules.items():
                assert hasattr(submodule, "ds_grads_remaining"), \
                    f"linears.{name} does not have variable ds_grads_remaining"
                assert submodule.ds_grads_remaining == 0, \
                    f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"
            model_engine.step()
```
