diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index b7572c593c65..65c021d98e3b 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -106,6 +106,11 @@ def swiglu(x, y=None): x, y = paddle.chunk(x, chunks=2, axis=-1) return F.silu(x) * y +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + __all__ = [ "DeepseekV2LMHead", @@ -1089,7 +1094,7 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: @to_static(backend="CINN") -def qkv_pre_process( +def qkv_pre_process_no_fuse( q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids ): bsz, q_len, _ = q.shape @@ -1125,6 +1130,50 @@ def qkv_pre_process( return query_states, key_states, value_states +@to_static(backend="CINN") +def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads): + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]]) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return key_states, value_states + + +def qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + if (fused_partial_rope is None) or (position_ids is not None): + return qkv_pre_process_no_fuse( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids + ) + + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + kv = kv.reshape(shape=target_key_value_shape) + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]) + + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + query_states = fused_partial_rope(q, cos, sin) + k_pe = fused_partial_rope(k_pe, cos, sin) + + key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads) + + return query_states, key_states, value_states + + def manul_fwd( q_init, kv_init,