@@ -1229,7 +1229,7 @@ def extra_repr(self) -> str:
         return s


-class QKVCrossParallelLinear(torch.nn.Module):
+class QKVCrossParallelLinear(LinearBase):

     def __init__(self,
                  hidden_size: int,
@@ -1241,12 +1241,26 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
-        super().__init__()
+        # input_size and output_size are not used, just for alignment
+        input_size = hidden_size
+        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
         # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
+        placeholder_size = 0
+        quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        quant_method.create_weights(self,
+                                    placeholder_size, [placeholder_size],
+                                    placeholder_size,
+                                    placeholder_size,
+                                    self.params_dtype,
+                                    weight_loader=self.weight_loader_weight)
+
         # Use a dictionary to avoid submodules parameters auto-registration:
         # drop-in replacement for a `QKVParallelLinear` module.
         self.proj = dict()
@@ -1321,4 +1335,4 @@ def weight_loader_bias(self,
             param.weight_loader(
                 param,
                 loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                    param, loaded_weight, loaded_shard_id)
+                    param, loaded_weight, loaded_shard_id)
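
The core change in the constructor hunk is the placeholder weight: instead of hand-registering an empty torch.nn.Parameter() and attaching the loader through set_weight_attrs, the layer now asks its quantization method to create a zero-sized weight, so quantized configs can register whatever placeholder tensors they need. A minimal sketch of what this amounts to on the unquantized path, assuming UnquantizedLinearMethod-style behaviour (register_placeholder_weight below is illustrative, not vLLM code):

# Sketch only: create_weights(self, 0, [0], 0, 0, dtype, weight_loader=...)
# effectively registers a 0x0 `weight` whose only job is to carry the
# `weight_loader` attribute that checkpoint loading dispatches to.
import torch


def register_placeholder_weight(layer: torch.nn.Module,
                                params_dtype: torch.dtype,
                                weight_loader) -> None:
    weight = torch.nn.Parameter(torch.empty(0, 0, dtype=params_dtype),
                                requires_grad=False)
    layer.register_parameter("weight", weight)
    # Same effect as set_weight_attrs(weight, {"weight_loader": ...}).
    weight.weight_loader = weight_loader


layer = torch.nn.Module()
register_placeholder_weight(layer, torch.float16,
                            weight_loader=lambda p, w, *shard_id: None)
assert layer.weight.numel() == 0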
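The unchanged context lines also show the trick that keeps the wrapped projections out of the wrapper's parameter tree: self.proj is a plain dict, not an nn.ModuleDict, so the underlying projection modules are never auto-registered and the wrapper exposes only its placeholder weight and bias. A small, vLLM-independent demonstration of that behaviour:

# Submodules stored in a plain dict are invisible to nn.Module registration,
# so only the wrapper's own parameters appear in named_parameters().
import torch


class Wrapper(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = dict()                      # not nn.ModuleDict on purpose
        self.proj["q"] = torch.nn.Linear(8, 8)  # hidden from registration
        self.own = torch.nn.Parameter(torch.empty(0))  # registered as usual


w = Wrapper()
print(sorted(name for name, _ in w.named_parameters()))
# -> ['own']  (nothing from self.proj)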