@@ -1227,3 +1227,98 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
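+    """Drop-in replacement for `QKVParallelLinear` in cross-attention layers:
+    Q is projected from the decoder hidden states through a
+    `ColumnParallelLinear`, while K/V are projected from the encoder hidden
+    states through a `QKVParallelLinear` built with zero query heads."""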
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dict to avoid auto-registering the submodules' parameters,
+        # keeping this a drop-in replacement for `QKVParallelLinear`.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Decode phase: encoder KV is already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase: encoder KV is computed and cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split the fused projection into K and V halves.
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE: Delegate to the QKV/ColumnParallel weight_loader and ignore
+        # the placeholder `param`.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        # Same delegation as for the weights, but targeting the bias tensors.
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
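
For context, below is a minimal sketch of how this layer is meant to be called from a cross-attention block. The `CrossAttention` wrapper and its argument names are illustrative only and are not part of this diff; the sketch also assumes `QKVCrossParallelLinear` is importable from the module being modified, and that vLLM's tensor-parallel state has been initialized before construction, as the other parallel linear layers require.

import torch


class CrossAttention(torch.nn.Module):
    """Illustrative cross-attention block; names here are hypothetical."""

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int, prefix: str = ""):
        super().__init__()
        # One module stands in for separate q/k/v projections, so checkpoint
        # shards "q", "k" and "v" are all routed through its weight loaders.
        self.qkv_proj = QKVCrossParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=num_heads,
            total_num_kv_heads=num_kv_heads,
            bias=False,
            prefix=f"{prefix}.qkv_proj")

    def forward(self, decoder_hidden_states, encoder_hidden_states):
        # Prefill: encoder_hidden_states is a tensor, so fresh K/V are
        # returned. Decode: encoder_hidden_states is None, K/V come back as
        # None, and the attention backend reads the cached encoder K/V.
        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)
        ...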