@@ -1097,8 +1097,7 @@ def __init__(self,
 
         self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig,
                                                                    "use_all_reduce_for_fetch_params")
-        self.allgather_single_param = get_config_default(DeepSpeedZeroConfig,
-                                                          "allgather_single_param")
+        self.allgather_single_param = get_config_default(DeepSpeedZeroConfig, "allgather_single_param")
         if _ds_config is not None:
             self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params
             self.allgather_single_param = _ds_config.zero_config.allgather_single_param
@@ -1315,7 +1314,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                     for param in params:
                         buffer_size = math.ceil(param.ds_numel / world_size) * world_size
                         if use_secondary_tensor:
-                            buffer_size = param.ds_secondary_tensor.shape[0] * world_size  #make sure out is appropriately sized
+                            buffer_size = param.ds_secondary_tensor.shape[
+                                0] * world_size  #make sure out is appropriately sized
 
                         param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
 
@@ -1339,7 +1339,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                         )
 
                         if original_dtype == allgather_dtype:
-                            param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
+                            param.data = param_buffer.narrow(0, 0,
+                                                             param.ds_numel).view(param.ds_shape).to(param.device)
                             handles.append(AllGatherHandle(handle, param))
                         else:
                             # This case is complicated:
@@ -1355,7 +1356,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                             # In theory, this path could be consolidated with the case where
                             # (original_dtype == allgather_dtype), but because it changes the
                             # state transition of DeepSpeed parameters, we keep it separate for safety.
-                            handles.append(AllGatherHandle(handle,
+                            handles.append(
+                                AllGatherHandle(handle,
                                                            param,
                                                            param_buffer=param_buffer,
                                                            original_dtype=original_dtype))
@@ -1375,7 +1377,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
                             requires_grad=False,
                         )
                         quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
-                                                      quant_scale_buffer, ds_process_group)
+                                                          quant_scale_buffer, ds_process_group)
                         quant_info = QuantizationInfo()
                         quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
                             param.device)
@@ -2015,8 +2017,8 @@ def _allgather_params(self, param_list, hierarchy=0):
             scale_size = param.ds_tensor.ds_quant_scale.numel()
             scale_tensor_size = scale_size * self.num_partitions
             flat_scale_tensor = torch.empty(scale_tensor_size,
-                                           dtype=param.ds_tensor.ds_quant_scale.dtype,
-                                           device=self.local_device)
+                                            dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                            device=self.local_device)
             flat_scale_tensor.requires_grad = False
 
             scale_partitions = []
@@ -2027,9 +2029,9 @@ def _allgather_params(self, param_list, hierarchy=0):
                 scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
 
         dist.all_gather_into_tensor(flat_tensor,
-                                   partitions[self.get_partition_rank()],
-                                   group=self.get_partition_dp_group(param),
-                                   async_op=False)
+                                    partitions[self.get_partition_rank()],
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)
 
         if hasattr(param, 'ds_quant_scale'):
             dist.all_gather(flat_scale_tensor,
@@ -2077,7 +2079,7 @@ def _allgather_params(self, param_list, hierarchy=0):
                     param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
 
                     scale_partitions[i].narrow(0, offset,
-                                              param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                                               param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
 
                     offset += param_scale_numel
 
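
The all_gather_coalesced hunks above preserve a separate handle path for the case where the communication dtype differs from the parameter's original dtype. The following is a minimal standalone sketch of that pattern, not part of this diff and not DeepSpeed's implementation: the function name, tensor names, and parameters are illustrative only, and it assumes an already-initialized torch.distributed process group.

# Hypothetical sketch (not DeepSpeed code): all-gather a local shard in a
# possibly lower-precision communication dtype, then cast the gathered buffer
# back to the original dtype once the collective completes.
import torch
import torch.distributed as dist


def gather_with_dtype_cast(shard: torch.Tensor,
                           original_dtype: torch.dtype,
                           allgather_dtype: torch.dtype,
                           group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group=group)
    # Communicate in allgather_dtype to reduce traffic when it is narrower.
    send = shard.to(allgather_dtype)
    recv = torch.empty(world_size * send.numel(), dtype=allgather_dtype, device=send.device)
    work = dist.all_gather_into_tensor(recv, send, group=group, async_op=True)
    work.wait()
    # Cast back only after the collective finishes, mirroring the separate
    # handle path taken when original_dtype != allgather_dtype.
    return recv.to(original_dtype) if original_dtype != allgather_dtype else recv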