
Commit 30814fa

adaptor func _allgather_params
Signed-off-by: aeeeeeep <aeeeeeep@proton.me>
1 parent 77a51f7

1 file changed: +101 -54 lines changed

deepspeed/runtime/zero/partition_parameters.py

@@ -1994,78 +1994,125 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return

-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions

-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)
+                flat_tensor.requires_grad = False
+
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                    if i == self.get_partition_rank():
+                        partitions[i].copy_(param.ds_tensor.data)

-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False

-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)

-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data

-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])

-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel

-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)

-                        offset += param_scale_numel
+                        offset += param_numel

-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0

-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)

-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):

-                start = i * partition_size
+                    start = i * partition_size

-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size

-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)

-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-            #param_offset += param.data.numel()
-            param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                #param_offset += param.data.numel()
+                param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data

         return None
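For reference, the new allgather_single_param branch reduces to one dist.all_gather_into_tensor call per parameter. Below is a minimal, self-contained sketch of that pattern outside DeepSpeed; the helper name allgather_param is illustrative, not part of the patch, and it assumes an initialized torch.distributed process group where every rank holds an equal-sized (possibly padded) partition:

    import torch
    import torch.distributed as dist

    def allgather_param(partition: torch.Tensor, full_shape: torch.Size) -> torch.Tensor:
        """Rebuild one full parameter from equal-sized per-rank partitions."""
        world_size = dist.get_world_size()
        # One flat buffer holding world_size partitions back to back,
        # mirroring the flat_tensor the patch allocates per parameter.
        flat = torch.empty(partition.numel() * world_size,
                           dtype=partition.dtype,
                           device=partition.device)
        flat.requires_grad = False
        # Each rank contributes its own partition; the collective lays the
        # slices out in rank order inside `flat`.
        dist.all_gather_into_tensor(flat, partition.contiguous())
        # The last rank's partition may be padded, so trim to the true numel
        # before restoring the original shape, as the patch does with
        # flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).
        return flat.narrow(0, 0, full_shape.numel()).view(full_shape)

Compared with the pre-existing batched path, which packs every partition in param_list into a single collective, the per-parameter form trades one large all-gather plus offset bookkeeping for several smaller collectives that need no offsets.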

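The quantized case follows the same shape: the scales travel through their own collective and are applied after the gather. A sketch under the same assumptions; dequantize below is a hypothetical stand-in for DeepSpeed's quantizer_module.dequantize and assumes one scale per rank's partition:

    import torch
    import torch.distributed as dist

    def dequantize(flat: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
        # Hypothetical per-partition dequantization: scale each rank's slice.
        parts = scales.numel()
        return (flat.float().view(parts, -1) * scales.float().view(parts, 1)).view(-1)

    def allgather_quantized(partition: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Gather quantized partitions and their scales, then dequantize."""
        world_size = dist.get_world_size()
        flat = torch.empty(partition.numel() * world_size,
                           dtype=partition.dtype, device=partition.device)
        flat_scales = torch.empty(scale.numel() * world_size,
                                  dtype=scale.dtype, device=scale.device)
        dist.all_gather_into_tensor(flat, partition.contiguous())
        # Scales get a separate collective, as in the patch's
        # dist.all_gather(flat_scale_tensor, param.ds_tensor.ds_quant_scale, ...).
        dist.all_gather_into_tensor(flat_scales, scale.contiguous())
        return dequantize(flat, flat_scales)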