Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def swiglu(x, y=None):
x, y = paddle.chunk(x, chunks=2, axis=-1)
return F.silu(x) * y

try:
from paddle.incubate.nn.functional import fused_partial_rope
except ImportError:
fused_partial_rope = None


__all__ = [
"DeepseekV2LMHead",
Expand Down Expand Up @@ -1089,7 +1094,7 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:


@to_static(backend="CINN")
def qkv_pre_process(
def qkv_pre_process_no_fuse(
q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
):
bsz, q_len, _ = q.shape
Expand Down Expand Up @@ -1125,6 +1130,50 @@ def qkv_pre_process(
return query_states, key_states, value_states


@to_static(backend="CINN")
def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads):
k_nope = kv[..., :qk_nope_head_dim]
value_states = kv[..., qk_nope_head_dim:]

k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]])
key_states = paddle.concat([k_nope, k_pe], axis=-1)

return key_states, value_states


def qkv_pre_process(
q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
):
if (fused_partial_rope is None) or (position_ids is not None):
return qkv_pre_process_no_fuse(
q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
)

bsz, q_len, _ = q.shape

target_query_shape = [0, 0, num_heads, q_head_dim]
target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim]

q = q.reshape(shape=target_query_shape)
kv = kv.reshape(shape=target_key_value_shape)
k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim])

value_states = kv[..., qk_nope_head_dim:]

kv_seq_len = value_states.shape[1]

cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]

query_states = fused_partial_rope(q, cos, sin)
k_pe = fused_partial_rope(k_pe, cos, sin)

key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads)

return query_states, key_states, value_states


def manul_fwd(
q_init,
kv_init,
Expand Down
Loading