1313
1414import deepspeed
1515from deepspeed .utils import safe_get_full_fp32_param , safe_get_full_grad , safe_get_full_optimizer_state
16- from deepspeed .utils import safe_set_full_fp32_param , safe_set_full_optimizer_state
16+ from deepspeed .utils import safe_set_full_fp32_param , safe_set_full_grad , safe_set_full_optimizer_state
1717from deepspeed .utils import safe_get_local_fp32_param , safe_get_local_grad , safe_get_local_optimizer_state
18- from deepspeed .utils import safe_set_local_fp32_param , safe_set_local_optimizer_state
18+ from deepspeed .utils import safe_set_local_fp32_param , safe_set_local_grad , safe_set_local_optimizer_state
1919from deepspeed .runtime .zero .offload_config import OffloadDeviceEnum
2020from deepspeed .ops .aio import AsyncIOBuilder
2121from deepspeed .accelerator import get_accelerator
2222
2323WEIGHT_KEY = 'weight'
2424FIRST_ORDER_KEY = 'exp_avg'
2525SECOND_ORDER_KEY = 'exp_avg_sq'
26+ GRADIENT_KEY = 'gradient'
2627
2728
2829def validate_tensor (model , api_type , opt_states ):
@@ -180,13 +181,14 @@ def test_bf16_fragments(self, frozen_weights):
180181 run_fragmented_model (model , config_dict , hidden_dim , torch .bfloat16 , validate_after_bwd , validate_after_step )
181182
182183
def create_random_values(model, key_list, group, grad_dtype, use_cuda=True):
    """Create one random tensor per (parameter, key) pair and broadcast it from rank 0.

    Args:
        model: object exposing ``named_parameters()`` and a ``device`` attribute.
        key_list: state keys (e.g. ``WEIGHT_KEY``, ``GRADIENT_KEY``) to generate values for.
        group: process group handed to ``dist.broadcast``.
        grad_dtype: dtype used for the ``GRADIENT_KEY`` entry; all other keys use ``torch.float32``.
        use_cuda: accepted for signature parity with ``create_random_values_for_local``;
            this function does not read it and always allocates on ``model.device``.

    Returns:
        dict mapping parameter name -> {key: tensor}.
    """
    param_values = {}
    for n, lp in model.named_parameters():
        # Parameters that carry a ``ds_id`` attribute report their shape through
        # ``ds_shape`` (presumably ZeRO-3 partitioned params -- confirm).
        param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
        param_values[n] = {}
        for key in key_list:
            dtype = grad_dtype if key == GRADIENT_KEY else torch.float32
            rand_value = torch.rand(param_shape, dtype=dtype, device=model.device)
            # Overwrite every rank's tensor with rank 0's values.
            dist.broadcast(rand_value, src=0, group=group)
            param_values[n][key] = rand_value
    return param_values
@@ -195,7 +197,9 @@ def create_random_values(model, key_list, group, use_cuda=True):
def set_param_values_with_dict(model, value_dict):
    """Write the tensors in ``value_dict`` into the model via the ``safe_set_full_*`` APIs.

    Args:
        model: object exposing ``named_parameters()``.
        value_dict: mapping parameter name -> {key: tensor}. ``GRADIENT_KEY`` is routed to
            ``safe_set_full_grad``, ``WEIGHT_KEY`` to ``safe_set_full_fp32_param``, and any
            other key is passed to ``safe_set_full_optimizer_state`` as the state name.

    Raises:
        KeyError: if a parameter name is missing from ``value_dict``.
    """
    for n, lp in model.named_parameters():
        for key, value_tensor in value_dict[n].items():
            if key == GRADIENT_KEY:
                safe_set_full_grad(lp, value_tensor)
            elif key == WEIGHT_KEY:
                safe_set_full_fp32_param(lp, value_tensor)
            else:
                # Every remaining key is treated as an optimizer-state name.
                safe_set_full_optimizer_state(lp, value_tensor, key)
@@ -204,21 +208,25 @@ def set_param_values_with_dict(model, value_dict):
def validate_param_values_with_dict(model, value_dict):
    """Assert that the model's full-tensor state equals the tensors in ``value_dict``.

    Args:
        model: object exposing ``named_parameters()``.
        value_dict: mapping parameter name -> {key: expected tensor}. ``GRADIENT_KEY`` is read
            with ``safe_get_full_grad``, ``WEIGHT_KEY`` with ``safe_get_full_fp32_param``, and
            any other key with ``safe_get_full_optimizer_state``.

    Raises:
        AssertionError: if ``torch.equal`` reports a mismatch for any (parameter, key) pair.
        KeyError: if a parameter name is missing from ``value_dict``.
    """
    for n, lp in model.named_parameters():
        for key, expected_tensor in value_dict[n].items():
            if key == GRADIENT_KEY:
                actual_tensor = safe_get_full_grad(lp)
            elif key == WEIGHT_KEY:
                actual_tensor = safe_get_full_fp32_param(lp)
            else:
                # Every remaining key is treated as an optimizer-state name.
                actual_tensor = safe_get_full_optimizer_state(lp, key)

            # Name the failing parameter/key so a mismatch is easy to locate.
            assert torch.equal(expected_tensor, actual_tensor), \
                f"full '{key}' value mismatch for parameter '{n}'"
212219
213220
def create_random_values_for_local(model, key_list, group, grad_dtype, use_cuda=True):
    """Create one random tensor per (parameter, key) pair, shaped like the local ``ds_tensor``.

    Args:
        model: object exposing ``named_parameters()`` and, when ``use_cuda`` is true, ``device``.
        key_list: state keys (e.g. ``WEIGHT_KEY``, ``GRADIENT_KEY``) to generate values for.
        group: unused here; kept so the signature matches ``create_random_values``.
        grad_dtype: dtype used for the ``GRADIENT_KEY`` entry; all other keys use ``torch.float32``.
        use_cuda: allocate on ``model.device`` when true, otherwise on ``"cpu"``.

    Returns:
        dict mapping parameter name -> {key: tensor}.
    """
    # The target device does not depend on the parameter or key, so pick it once.
    device = model.device if use_cuda else "cpu"
    param_values = {}
    for n, lp in model.named_parameters():
        param_shape = lp.ds_tensor.shape
        param_values[n] = {}
        for key in key_list:
            dtype = grad_dtype if key == GRADIENT_KEY else torch.float32
            # NOTE(review): unlike create_random_values, the tensor is not broadcast
            # (the original had the dist.broadcast call disabled) -- presumably each
            # rank keeps its own values for its local shard; confirm.
            param_values[n][key] = torch.rand(param_shape, dtype=dtype, device=device)
    return param_values
@@ -228,7 +236,9 @@ def set_local_param_values_with_dict(model, value_dict):
228236 for n , lp in model .named_parameters ():
229237
230238 for key , value_tensor in value_dict [n ].items ():
231- if key == WEIGHT_KEY :
239+ if key == GRADIENT_KEY :
240+ safe_set_local_grad (lp , value_tensor )
241+ elif key == WEIGHT_KEY :
232242 safe_set_local_fp32_param (lp , value_tensor )
233243 else :
234244 safe_set_local_optimizer_state (lp , value_tensor , key )
@@ -237,10 +247,13 @@ def set_local_param_values_with_dict(model, value_dict):
def validate_local_param_values_with_dict(model, value_dict):
    """Assert that the model's local (per-rank) state equals the tensors in ``value_dict``.

    Args:
        model: object exposing ``named_parameters()``.
        value_dict: mapping parameter name -> {key: expected tensor}. ``GRADIENT_KEY`` is read
            with ``safe_get_local_grad``, ``WEIGHT_KEY`` with ``safe_get_local_fp32_param``, and
            any other key with ``safe_get_local_optimizer_state``.

    Raises:
        AssertionError: if ``torch.equal`` reports a mismatch for any (parameter, key) pair.
        KeyError: if a parameter name is missing from ``value_dict``.
    """
    for n, lp in model.named_parameters():
        for key, expected_tensor in value_dict[n].items():
            if key == GRADIENT_KEY:
                actual_tensor = safe_get_local_grad(lp)
            elif key == WEIGHT_KEY:
                actual_tensor = safe_get_local_fp32_param(lp)
            else:
                # Every remaining key is treated as an optimizer-state name.
                actual_tensor = safe_get_local_optimizer_state(lp, key)

            # Name the failing parameter/key so a mismatch is easy to locate.
            assert torch.equal(expected_tensor, actual_tensor), \
                f"local '{key}' value mismatch for parameter '{n}'"
245258
246259
@@ -325,12 +338,20 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp
325338
326339 dist .barrier ()
327340
def after_bwd_validate_func(model):
    """Round-trip random weight and gradient values through the selected API after backward.

    Looks up the helper set for ``api_type`` from the enclosing test, generates random
    values for ``WEIGHT_KEY`` and ``GRADIENT_KEY``, writes them into ``model`` and then
    checks that reading them back yields the same tensors.
    """
    state_keys = [WEIGHT_KEY, GRADIENT_KEY]
    helper_funcs = helper_funcs_mapping[api_type]
    # Tensors are created on the accelerator only when no offload device is configured.
    state_values = helper_funcs["create_random_values"](
        model, state_keys, group, grad_dtype=dtype, use_cuda=offload_device == OffloadDeviceEnum.none)
    helper_funcs["set_param_values_with_dict"](model, state_values)
    helper_funcs["validate_param_values_with_dict"](model, state_values)
348+
def after_step_validate_func(model):
    """Round-trip random weight and optimizer-state values through the selected API after step.

    Looks up the helper set for ``api_type`` from the enclosing test, generates random
    values for ``WEIGHT_KEY``, ``FIRST_ORDER_KEY`` and ``SECOND_ORDER_KEY``, writes them
    into ``model`` and then checks that reading them back yields the same tensors.
    """
    state_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
    helper_funcs = helper_funcs_mapping[api_type]
    # ``grad_dtype`` is forwarded for signature compatibility; no gradient key is requested here.
    optim_state_values = helper_funcs["create_random_values"](
        model, state_keys, group, grad_dtype=dtype, use_cuda=offload_device == OffloadDeviceEnum.none)
    helper_funcs["set_param_values_with_dict"](model, optim_state_values)
    helper_funcs["validate_param_values_with_dict"](model, optim_state_values)
335356
336- run_fragmented_model (model , config_dict , hidden_dim , dtype , lambda _ : None , validate_func )
357+ run_fragmented_model (model , config_dict , hidden_dim , dtype , after_bwd_validate_func , after_step_validate_func )
0 commit comments