@@ -106,6 +106,11 @@ def swiglu(x, y=None):
106106 x , y = paddle .chunk (x , chunks = 2 , axis = - 1 )
107107 return F .silu (x ) * y
108108
109+ try :
110+ from paddle .incubate .nn .functional import fused_partial_rope
111+ except ImportError :
112+ fused_partial_rope = None
113+
109114
110115__all__ = [
111116 "DeepseekV2LMHead" ,
@@ -1089,7 +1094,7 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
10891094
10901095
10911096@to_static (backend = "CINN" )
1092- def qkv_pre_process (
1097+ def qkv_pre_process_no_fuse (
10931098 q , kv , k_pe , rotary_emb , num_heads , q_head_dim , qk_nope_head_dim , v_head_dim , qk_rope_head_dim , position_ids
10941099):
10951100 bsz , q_len , _ = q .shape
@@ -1125,6 +1130,50 @@ def qkv_pre_process(
11251130 return query_states , key_states , value_states
11261131
11271132
@to_static(backend="CINN")
def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads):
    """Build the key and value states from the joint ``kv`` tensor and ``k_pe``.

    The last axis of ``kv`` is split at ``qk_nope_head_dim``: the leading slice
    is the non-rotary key part, the trailing slice is returned as the values.
    ``k_pe`` is expanded along axis 2 to ``num_heads`` and concatenated onto
    the non-rotary key part along the last axis.

    Args:
        kv: Tensor whose last axis holds the non-rotary key features followed
            by the value features.
            # assumes layout (batch, seq, num_heads, nope + v) -- TODO confirm
        k_pe: Tensor indexed as 4-D here; axis 2 is expanded to ``num_heads``.
        qk_nope_head_dim: Split point on the last axis of ``kv``.
        num_heads: Target size of axis 2 of the expanded ``k_pe``.

    Returns:
        A ``(key_states, value_states)`` tuple.
    """
    pe_shape = k_pe.shape
    # Keep every axis of k_pe except axis 2, which is widened to num_heads.
    expanded_shape = [pe_shape[0], pe_shape[1], num_heads, pe_shape[3]]
    rope_per_head = k_pe.expand(expanded_shape)

    nope_part = kv[..., :qk_nope_head_dim]
    value_states = kv[..., qk_nope_head_dim:]

    key_states = paddle.concat([nope_part, rope_per_head], axis=-1)
    return key_states, value_states
1142+
1143+
def qkv_pre_process(
    q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
):
    """Reshape q/kv/k_pe into per-head tensors and apply rotary embedding.

    Dispatches to ``qkv_pre_process_no_fuse`` when the ``fused_partial_rope``
    kernel could not be imported or when explicit ``position_ids`` are given;
    otherwise the fused kernel is used for both the query and ``k_pe``.

    Args:
        q: Query tensor; unpacked as 3-D and reshaped to
            ``[.., .., num_heads, q_head_dim]``.
        kv: Joint key/value tensor; reshaped to
            ``[.., .., num_heads, qk_nope_head_dim + v_head_dim]``.
        k_pe: Rotary key tensor; reshaped to ``[-1, q_len, 1, qk_rope_head_dim]``.
        rotary_emb: Callable returning ``(cos, sin)`` when invoked as
            ``rotary_emb(x, seq_len=...)``.
        num_heads: Number of attention heads.
        q_head_dim: Per-head size of the query.
        qk_nope_head_dim: Per-head size of the non-rotary key part.
        v_head_dim: Per-head size of the value.
        qk_rope_head_dim: Size of the rotary part of the key.
        position_ids: Optional position ids; any non-``None`` value selects
            the non-fused implementation.

    Returns:
        A ``(query_states, key_states, value_states)`` tuple.
    """
    can_use_fused_rope = fused_partial_rope is not None and position_ids is None
    if not can_use_fused_rope:
        return qkv_pre_process_no_fuse(
            q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
        )

    _, q_len, _ = q.shape

    # A 0 in a Paddle reshape target keeps the corresponding input dimension.
    q_heads = q.reshape(shape=[0, 0, num_heads, q_head_dim])
    kv_heads = kv.reshape(shape=[0, 0, num_heads, qk_nope_head_dim + v_head_dim])
    rope_key = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim])

    # The value slice only supplies the reference tensor and sequence length
    # for the rotary table lookup; rearrange_kv re-derives the values below.
    value_slice = kv_heads[..., qk_nope_head_dim:]
    cos, sin = rotary_emb(value_slice, seq_len=value_slice.shape[1])
    # Insert singleton leading and head axes so cos/sin become 4-D.
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]

    # NOTE(review): presumably fused_partial_rope rotates only the trailing
    # rotary features of its input -- confirm against the Paddle kernel docs.
    query_states = fused_partial_rope(q_heads, cos, sin)
    rope_key = fused_partial_rope(rope_key, cos, sin)

    key_states, value_states = rearrange_kv(kv_heads, rope_key, qk_nope_head_dim, num_heads)

    return query_states, key_states, value_states
1175+
1176+
11281177def manul_fwd (
11291178 q_init ,
11301179 kv_init ,
0 commit comments