Commit 18ee6cf

Add API for updating ZeRO gradients

1 parent 2a56f53 commit 18ee6cf

5 files changed: 139 additions & 36 deletions

deepspeed/runtime/zero/stage3.py

Lines changed: 43 additions & 7 deletions

@@ -2296,6 +2296,24 @@ def get_fp32_grad_for_param(self, param) -> Tensor:
 
         return self._fp32_state_allgather(param, fp32_grad)
 
+    def set_fp32_grad_for_param(self, value, param):
+        if not param.requires_grad:
+            return
+
+        if not get_accelerator().resolves_data_dependency():
+            self.reduce_and_partition_stream.synchronize()
+
+        if self.offload_optimizer:
+            group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
+            fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
+        else:
+            fp32_grad = self.__param_id_to_grad_partition[param.ds_id]
+
+        my_rank = dist.get_rank(group=self.dp_process_group)
+        value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel())
+
+        fp32_grad.data.copy_(value_partition.data)
+
     def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
         if not get_accelerator().resolves_data_dependency():
             self.reduce_and_partition_stream.synchronize()

@@ -2344,12 +2362,6 @@ def set_full_hp_param(self, value, param, optim_state_key=None):
 
     ### Local API START ###
 
-    def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
-        if not param.requires_grad:
-            return None
-        fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
-        return fp32_opt_state
-
     def get_local_fp32_grad_for_param(self, param) -> Tensor:
         if not param.requires_grad:
             return None

@@ -2364,6 +2376,30 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor:
             fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
         return fp32_grad
 
+    def set_local_grad_for_param(self, value, param):
+        if not param.requires_grad:
+            return
+
+        assert value.numel() == param.ds_tensor.numel(
+        ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}"
+
+        if not get_accelerator().resolves_data_dependency():
+            self.reduce_and_partition_stream.synchronize()
+
+        if self.offload_optimizer:
+            group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
+            fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
+        else:
+            fp32_grad = self.__param_id_to_grad_partition[param.ds_id]
+
+        fp32_grad.data.copy_(value.flatten().data)
+
+    def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
+        if not param.requires_grad:
+            return None
+        fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
+        return fp32_opt_state
+
     def set_local_hp_param(self, value, param, optim_state_key=None):
         if not param.requires_grad:
             return

@@ -2378,7 +2414,7 @@ def set_local_hp_param(self, value, param, optim_state_key=None):
 
         if self._swappable_optimizer_subgroup(group_idx):
             self._optimizer_states_and_gradient_swap_out(group_idx)
-        logger.info(f"[set_local_hp_param][update the params' value successfully]")
+        # logger.info(f"[set_local_hp_param][update the params' value successfully]")
 
     ### Local API END ###
 
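Note on the stage-3 full-gradient setter above: each rank copies only its own slice of the caller-supplied full-shaped tensor into its partitioned fp32 gradient buffer. A minimal standalone sketch of that slicing logic (illustration only, not part of the commit; partition_for_rank is a hypothetical helper):

import torch

def partition_for_rank(full_value: torch.Tensor, partition_numel: int, my_rank: int) -> torch.Tensor:
    # Mirrors the narrow() in set_fp32_grad_for_param: rank r owns elements
    # [r * partition_numel, (r + 1) * partition_numel) of the flattened value.
    return full_value.flatten().narrow(0, partition_numel * my_rank, partition_numel)

# Two-rank partitioning of an 8-element gradient:
full_grad = torch.arange(8, dtype=torch.float32)
print(partition_for_rank(full_grad, 4, 0))  # tensor([0., 1., 2., 3.])
print(partition_for_rank(full_grad, 4, 1))  # tensor([4., 5., 6., 7.])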

deepspeed/utils/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -12,10 +12,10 @@
 # TODO: Move tensor fragment and mixed precision to zero utils
 from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
-from .tensor_fragment import set_full_hp_param
-from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
+from .tensor_fragment import set_full_hp_param, set_full_hp_grad
+from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
 from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
-from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
+from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
 from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
 from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
 from deepspeed.runtime.dataloader import RepeatingLoader
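With these re-exports, the new gradient setters can be imported from deepspeed.utils next to the existing getters, e.g.:

from deepspeed.utils import safe_get_full_grad, safe_set_full_grad
from deepspeed.utils import safe_get_local_grad, safe_set_local_grad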

deepspeed/utils/mixed_precision_linkage.py

Lines changed: 2 additions & 1 deletion

@@ -5,7 +5,7 @@
 
 import types
 from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
-from deepspeed.utils import set_full_hp_param
+from deepspeed.utils import set_full_hp_param, set_full_hp_grad
 
 
 def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,

@@ -35,6 +35,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
         lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
         lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
         lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
+        lp_param.set_full_hp_grad = types.MethodType(set_full_hp_grad, lp_param)
 
         # lp_param overlaps with partition if both are true
         # 1) current_offset < partition_end,
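The added line follows the existing binding pattern: types.MethodType attaches a free function as a bound method on each individual lp parameter. A standalone sketch of that pattern (the Param class and placeholder body are illustrative, not DeepSpeed code):

import types

class Param:
    """Stand-in for a low-precision parameter object."""
    pass

def set_full_hp_grad(self, value):
    # Placeholder body; the real implementation lives in tensor_fragment.py.
    self.last_value = value

p = Param()
p.set_full_hp_grad = types.MethodType(set_full_hp_grad, p)  # bound to this object only
p.set_full_hp_grad([1.0, 2.0])
assert p.last_value == [1.0, 2.0]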

deepspeed/utils/tensor_fragment.py

Lines changed: 56 additions & 11 deletions

@@ -57,6 +57,17 @@ def get_hp_fragment(self, optim_state_key=None):
             return self.hp_fragment
         return self.get_optim_state_fragment(optim_state_key)
 
+    def get_lp_grad_fragment(self, index_in_param_group):
+        if self.use_offload:
+            gradient_dict = self.offload_gradient_dict
+        else:
+            gradient_dict = self.gradient_dict
+
+        if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None:
+            raise ValueError("Gradients are only available immediately after backward and before engine step")
+
+        return gradient_dict[self.param_group_index][index_in_param_group]
+
 
 def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
     for key in opt_keys:

@@ -95,17 +106,7 @@ def set_full_hp_param(self, value, optim_state_key=None):
 def get_full_hp_grad(self):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
-        hp_mapping = self._hp_mapping
-
-        if hp_mapping.use_offload:
-            gradient_dict = hp_mapping.offload_gradient_dict
-        else:
-            gradient_dict = hp_mapping.gradient_dict
-
-        if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
-            raise ValueError("Gradients are only available immediately after backward and before engine step")
-
-        lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
+        lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
         hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()
 
         lp_frag_address = self._hp_mapping.lp_fragment_address

@@ -120,6 +121,14 @@ def get_full_hp_grad(self):
     return reduce_buffer.reshape_as(self)
 
 
+def set_full_hp_grad(self, value):
+    if self._hp_mapping is not None:
+        lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
+        lp_frag_address = self._hp_mapping.lp_fragment_address
+        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
+        lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))
+
+
 def safe_get_full_fp32_param(param):
     """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.
 

@@ -207,6 +216,26 @@ def safe_get_full_grad(param):
     return None
 
 
+def safe_set_full_grad(param, value):
+    """Update the partitioned gradient of a low-precision (e.g., fp16) parameter.
+
+    Args:
+        param (``torch.nn.Parameter``): A model parameter
+        value (``torch.Tensor``): New value
+    """
+    if param.grad is not None:
+        param.grad.copy_(value)
+        return
+
+    # ZeRO stage 3 param
+    if hasattr(param, 'ds_id'):
+        param._z3_optimizer.set_fp32_grad_for_param(value, param)
+
+    # ZeRO stage 1, 2, and bf16_optimizer params
+    if hasattr(param, '_hp_mapping'):
+        param.set_full_hp_grad(value)
+
+
 ### Local API START ###
 def safe_get_local_grad(param):
     """Get the fp32 gradient of a partitioned parameter.

@@ -223,6 +252,22 @@ def safe_get_local_grad(param):
     return None
 
 
+def safe_set_local_grad(param, value):
+    """Update the gradient of a partitioned parameter.
+    Args:
+        param (``torch.nn.Parameter``): A model parameter
+        value (``torch.Tensor``): New value
+    """
+    if param.grad is not None:
+        return param.grad.copy_(value)
+
+    # ZeRO stage 3 param
+    if hasattr(param, 'ds_id'):
+        return param._z3_optimizer.set_local_grad_for_param(value, param)
+
+    return None
+
+
 def safe_get_local_fp32_param(param):
     """Get the fp32 partitioned parameter.
     Args:
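safe_set_full_grad is the user-facing entry point added here; like safe_get_full_grad, it targets the window after backward and before the engine step. A hedged usage sketch (the engine, loop, and clipping edit are illustrative, not from the commit):

import torch
from deepspeed.utils import safe_get_full_grad, safe_set_full_grad

def clip_one_param_grad(param, max_abs=1.0):
    full_grad = safe_get_full_grad(param)  # assembled fp32 gradient, full parameter shape
    if full_grad is None:                  # may be None, e.g. for frozen parameters
        return
    # Any elementwise edit of the full gradient, then scatter it back into the partitions.
    safe_set_full_grad(param, full_grad.clamp(-max_abs, max_abs))

# Inside a training loop, assuming a DeepSpeed engine named `engine`:
#   engine.backward(loss)
#   for p in engine.module.parameters():
#       clip_one_param_grad(p)
#   engine.step()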

tests/unit/runtime/zero/test_zero_tensor_fragment.py

Lines changed: 35 additions & 14 deletions

@@ -13,16 +13,17 @@
 
 import deepspeed
 from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
-from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
+from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_grad, safe_set_full_optimizer_state
 from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
-from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
+from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.aio import AsyncIOBuilder
 from deepspeed.accelerator import get_accelerator
 
 WEIGHT_KEY = 'weight'
 FIRST_ORDER_KEY = 'exp_avg'
 SECOND_ORDER_KEY = 'exp_avg_sq'
+GRADIENT_KEY = 'gradient'
 
 
 def validate_tensor(model, api_type, opt_states):

@@ -180,13 +181,14 @@ def test_bf16_fragments(self, frozen_weights):
         run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_after_bwd, validate_after_step)
 
 
-def create_random_values(model, key_list, group, use_cuda=True):
+def create_random_values(model, key_list, group, grad_dtype, use_cuda=True):
     param_values = {}
     for n, lp in model.named_parameters():
         param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
         param_values[n] = {}
         for key in key_list:
-            rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device)
+            dtype = grad_dtype if key == GRADIENT_KEY else torch.float32
+            rand_value = torch.rand(param_shape, dtype=dtype, device=model.device)
             dist.broadcast(rand_value, src=0, group=group)
             param_values[n][key] = rand_value
     return param_values

@@ -195,7 +197,9 @@ def create_random_values(model, key_list, group, use_cuda=True):
 def set_param_values_with_dict(model, value_dict):
     for n, lp in model.named_parameters():
         for key, value_tensor in value_dict[n].items():
-            if key == WEIGHT_KEY:
+            if key == GRADIENT_KEY:
+                safe_set_full_grad(lp, value_tensor)
+            elif key == WEIGHT_KEY:
                 safe_set_full_fp32_param(lp, value_tensor)
             else:
                 safe_set_full_optimizer_state(lp, value_tensor, key)

@@ -204,21 +208,25 @@ def set_param_values_with_dict(model, value_dict):
 def validate_param_values_with_dict(model, value_dict):
     for n, lp in model.named_parameters():
         for key, expected_tensor in value_dict[n].items():
-            if key == WEIGHT_KEY:
+            if key == GRADIENT_KEY:
+                actual_tensor = safe_get_full_grad(lp)
+            elif key == WEIGHT_KEY:
                 actual_tensor = safe_get_full_fp32_param(lp)
             else:
                 actual_tensor = safe_get_full_optimizer_state(lp, key)
+
             assert torch.equal(expected_tensor, actual_tensor)
 
 
-def create_random_values_for_local(model, key_list, group, use_cuda=True):
+def create_random_values_for_local(model, key_list, group, grad_dtype, use_cuda=True):
     param_values = {}
     for n, lp in model.named_parameters():
         param_shape = lp.ds_tensor.shape
         param_values[n] = {}
         for key in key_list:
             device = model.device if use_cuda else "cpu"
-            rand_value = torch.rand(param_shape, dtype=torch.float32, device=device)
+            dtype = grad_dtype if key == GRADIENT_KEY else torch.float32
+            rand_value = torch.rand(param_shape, dtype=dtype, device=device)
             # dist.broadcast(rand_value, src=0, group=group)
             param_values[n][key] = rand_value
     return param_values

@@ -228,7 +236,9 @@ def set_local_param_values_with_dict(model, value_dict):
     for n, lp in model.named_parameters():
 
         for key, value_tensor in value_dict[n].items():
-            if key == WEIGHT_KEY:
+            if key == GRADIENT_KEY:
+                safe_set_local_grad(lp, value_tensor)
+            elif key == WEIGHT_KEY:
                 safe_set_local_fp32_param(lp, value_tensor)
             else:
                 safe_set_local_optimizer_state(lp, value_tensor, key)

@@ -237,10 +247,13 @@ def set_local_param_values_with_dict(model, value_dict):
 def validate_local_param_values_with_dict(model, value_dict):
     for n, lp in model.named_parameters():
         for key, expected_tensor in value_dict[n].items():
-            if key == WEIGHT_KEY:
+            if key == GRADIENT_KEY:
+                actual_tensor = safe_get_local_grad(lp)
+            elif key == WEIGHT_KEY:
                 actual_tensor = safe_get_local_fp32_param(lp)
             else:
                 actual_tensor = safe_get_local_optimizer_state(lp, key)
+
             assert torch.equal(expected_tensor, actual_tensor)
 
 

@@ -325,12 +338,20 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp
 
         dist.barrier()
 
-        def validate_func(model):
-            optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
+        def after_bwd_validate_func(model):
+            state_keys = [WEIGHT_KEY, GRADIENT_KEY]
+            helper_funcs = helper_funcs_mapping[api_type]
+            optim_state_values = helper_funcs["create_random_values"](
+                model, state_keys, group, grad_dtype=dtype, use_cuda=offload_device == OffloadDeviceEnum.none)
+            helper_funcs["set_param_values_with_dict"](model, optim_state_values)
+            helper_funcs["validate_param_values_with_dict"](model, optim_state_values)
+
+        def after_step_validate_func(model):
+            state_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
             helper_funcs = helper_funcs_mapping[api_type]
             optim_state_values = helper_funcs["create_random_values"](
-                model, optim_keys, group, use_cuda=offload_device == OffloadDeviceEnum.none)
+                model, state_keys, group, grad_dtype=dtype, use_cuda=offload_device == OffloadDeviceEnum.none)
             helper_funcs["set_param_values_with_dict"](model, optim_state_values)
             helper_funcs["validate_param_values_with_dict"](model, optim_state_values)
 
-        run_fragmented_model(model, config_dict, hidden_dim, dtype, lambda _: None, validate_func)
+        run_fragmented_model(model, config_dict, hidden_dim, dtype, after_bwd_validate_func, after_step_validate_func)
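One detail the updated test helpers rely on: random values for GRADIENT_KEY are drawn in the model's gradient dtype, while weights and optimizer state stay fp32. A tiny standalone sketch of that rule (illustrative only):

import torch

GRADIENT_KEY = 'gradient'

def value_dtype(key, grad_dtype):
    # Mirrors the helpers above: gradients use the training dtype, everything else fp32.
    return grad_dtype if key == GRADIENT_KEY else torch.float32

assert value_dtype(GRADIENT_KEY, torch.bfloat16) == torch.bfloat16
assert value_dtype('weight', torch.bfloat16) == torch.float32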
