Commit c943e0e

aeeeeeep committed
format
Signed-off-by: aeeeeeep <aeeeeeep@proton.me>
1 parent d5f5525 commit c943e0e

File tree

1 file changed: +14 additions, -12 deletions

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 14 additions & 12 deletions
@@ -1097,8 +1097,7 @@ def __init__(self,
 
         self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig,
                                                                   "use_all_reduce_for_fetch_params")
-        self.allgather_single_param = get_config_default(DeepSpeedZeroConfig,
-                                                         "allgather_single_param")
+        self.allgather_single_param = get_config_default(DeepSpeedZeroConfig, "allgather_single_param")
         if _ds_config is not None:
             self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params
             self.allgather_single_param = _ds_config.zero_config.allgather_single_param
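For reference, `get_config_default` here reads a declared default off the config class without instantiating it. A minimal sketch of that pattern, assuming a pydantic-v2 style config; the class name, field names, and defaults below are illustrative stand-ins, not DeepSpeed's actual definitions:

```python
from pydantic import BaseModel


class ZeroConfigSketch(BaseModel):
    # Stand-ins for DeepSpeedZeroConfig fields; the defaults are assumed.
    use_all_reduce_for_fetch_params: bool = False
    allgather_single_param: bool = False


def get_config_default_sketch(config_cls: type[BaseModel], field_name: str):
    # pydantic v2 exposes declared fields (and their defaults) on the class,
    # so no config instance is needed to read a default value.
    return config_cls.model_fields[field_name].default


assert get_config_default_sketch(ZeroConfigSketch, "allgather_single_param") is False
```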
@@ -1315,7 +1314,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 for param in params:
                     buffer_size = math.ceil(param.ds_numel / world_size) * world_size
                     if use_secondary_tensor:
-                        buffer_size = param.ds_secondary_tensor.shape[0] * world_size  #make sure out is appropriately sized
+                        buffer_size = param.ds_secondary_tensor.shape[
+                            0] * world_size  #make sure out is appropriately sized
 
                     param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
 
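The sizing logic in this hunk pads each parameter's element count to a multiple of world_size, so every rank contributes an equally sized shard to the all-gather. A self-contained sketch of that arithmetic (the function name is invented for illustration):

```python
import math


def padded_allgather_numel(ds_numel: int, world_size: int) -> int:
    # The flat all-gather buffer must split evenly across ranks, so the
    # element count is rounded up to the next multiple of world_size.
    return math.ceil(ds_numel / world_size) * world_size


assert padded_allgather_numel(10, 4) == 12  # 3 elements per rank, 2 of padding
```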
@@ -1339,7 +1339,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                     )
 
                     if original_dtype == allgather_dtype:
-                        param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
+                        param.data = param_buffer.narrow(0, 0,
+                                                         param.ds_numel).view(param.ds_shape).to(param.device)
                         handles.append(AllGatherHandle(handle, param))
                     else:
                         # This case is complicated:
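The `narrow(...).view(...)` chain is how the gathered flat buffer becomes usable parameter data again: `narrow` drops the padding elements, `view` restores the original shape. A standalone illustration of the idiom; the shape and sizes here are made up:

```python
import torch

ds_shape = (3, 4)  # assumed original parameter shape
ds_numel = 12
buffer_size = 15   # ds_numel padded up to a multiple of world_size (5 ranks)

param_buffer = torch.arange(buffer_size, dtype=torch.float32)
# Drop the 3 padding elements, then restore the 2-D parameter shape.
param_data = param_buffer.narrow(0, 0, ds_numel).view(ds_shape)
assert param_data.shape == ds_shape
```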
@@ -1355,7 +1356,8 @@ def all_gather_coalesced(params: Iterable[Parameter],
                         # In theory, this path could be consolidated with the case where
                         # (original_dtype == allgather_dtype), but because it changes the
                         # state transition of DeepSpeed parameters, we keep it separate for safety.
-                        handles.append(AllGatherHandle(handle,
+                        handles.append(
+                            AllGatherHandle(handle,
                                             param,
                                             param_buffer=param_buffer,
                                             original_dtype=original_dtype))
@@ -1375,7 +1377,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
                             requires_grad=False,
                         )
                         quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
-                                                          quant_scale_buffer, ds_process_group)
+                                                           quant_scale_buffer, ds_process_group)
                         quant_info = QuantizationInfo()
                         quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
                             param.device)
@@ -2015,8 +2017,8 @@ def _allgather_params(self, param_list, hierarchy=0):
                 scale_size = param.ds_tensor.ds_quant_scale.numel()
                 scale_tensor_size = scale_size * self.num_partitions
                 flat_scale_tensor = torch.empty(scale_tensor_size,
-                                                dtype=param.ds_tensor.ds_quant_scale.dtype,
-                                                device=self.local_device)
+                                                 dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                 device=self.local_device)
                 flat_scale_tensor.requires_grad = False
 
                 scale_partitions = []
@@ -2027,9 +2029,9 @@ def _allgather_params(self, param_list, hierarchy=0):
                     scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
 
         dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
+                                     partitions[self.get_partition_rank()],
+                                     group=self.get_partition_dp_group(param),
+                                     async_op=False)
 
         if hasattr(param, 'ds_quant_scale'):
             dist.all_gather(flat_scale_tensor,
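`dist.all_gather_into_tensor` writes each rank's partition into its slice of one preallocated flat tensor, avoiding the per-rank output list that plain `dist.all_gather` requires. A minimal sketch of that call, assuming a process group has already been initialized (e.g. via `dist.init_process_group`); the wrapper function is invented for illustration:

```python
import torch
import torch.distributed as dist


def gather_partitions(partition: torch.Tensor) -> torch.Tensor:
    # Preallocate one flat output sized for every rank's equally sized shard;
    # rank i's data lands at offset i * partition.numel().
    world_size = dist.get_world_size()
    flat_tensor = torch.empty(partition.numel() * world_size,
                              dtype=partition.dtype,
                              device=partition.device)
    dist.all_gather_into_tensor(flat_tensor, partition, async_op=False)
    return flat_tensor
```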
@@ -2077,7 +2079,7 @@ def _allgather_params(self, param_list, hierarchy=0):
                     param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
 
                     scale_partitions[i].narrow(0, offset,
-                                               param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                                                param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
 
                     offset += param_scale_numel
 
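The final hunk is the running-offset copy that packs each parameter's quant scale into its slice of the gathered partition. The same `narrow(...).copy_(...)` idiom in isolation; the tensor sizes below are invented:

```python
import torch

scale_partition = torch.zeros(8)
per_param_scales = [torch.ones(3), torch.full((5,), 2.0)]  # assumed scales

offset = 0
for scale in per_param_scales:
    # Copy each scale into its slice of the flat partition, then advance.
    scale_partition.narrow(0, offset, scale.numel()).copy_(scale)
    offset += scale.numel()
```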