@@ -1994,78 +1994,125 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return

-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions

-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)
+                flat_tensor.requires_grad = False
+
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                    if i == self.get_partition_rank():
+                        partitions[i].copy_(param.ds_tensor.data)

-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False

-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)

-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data

-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])

-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel

-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)

-                        offset += param_scale_numel
+                        offset += param_numel

-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0

-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)

-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):

-                start = i * partition_size
+                    start = i * partition_size

-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size

-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)

-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-            #param_offset += param.data.numel()
-            param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                #param_offset += param.data.numel()
+                param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data

         return None

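For context, the new `allgather_single_param` branch gathers each parameter on its own: every rank contributes its padded shard to a flat buffer of `partition_size * num_partitions` elements, and the result is trimmed back to `param.ds_numel` and reshaped to `param.ds_shape`, which avoids the per-parameter offset bookkeeping the flattened multi-parameter path (the `else` branch) still needs. Below is a minimal standalone sketch of that per-parameter pattern, not DeepSpeed's actual API: it assumes `torch.distributed` is already initialized with a default process group, and the names `shard`, `full_shape`, and `allgather_one_param` are illustrative only.

import torch
import torch.distributed as dist

def allgather_one_param(shard: torch.Tensor, full_shape: torch.Size) -> torch.Tensor:
    """Gather one parameter's shard from every rank and rebuild the full tensor."""
    world_size = dist.get_world_size()
    partition_size = shard.numel()
    # One contiguous receive buffer holding every rank's shard back to back.
    flat = torch.empty(partition_size * world_size, dtype=shard.dtype, device=shard.device)
    flat.requires_grad = False
    # Rank i's shard lands at offset i * partition_size inside `flat`.
    dist.all_gather_into_tensor(flat, shard.contiguous())
    # Shards are padded to equal size, so trim to the true element count before reshaping.
    return flat.narrow(0, 0, full_shape.numel()).view(full_shape)

With NCCL, `all_gather_into_tensor` writes directly into the flat output buffer and so avoids the extra flatten/unflatten copies a list-based `dist.all_gather` would incur, which is presumably why both branches of the diff use it for the main weight buffer.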