Skip to content

Commit

Permalink
Fix Qwen2VL mrope for transformers 4.47.0
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Dec 10, 2024
1 parent 62a3c7d commit 2fd8f9d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
25 changes: 13 additions & 12 deletions src/liger_kernel/ops/qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
cos,
sin,
sl,
+ bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
Expand Down Expand Up @@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
t_end = mrope_section_t
h_end = t_end + mrope_section_h

- cos_row_idx = pid % sl
- t_cos = cos + cos_row_idx * hd
- h_cos = t_cos + sl * hd
- w_cos = h_cos + sl * hd
- t_sin = sin + cos_row_idx * hd
- h_sin = t_sin + sl * hd
- w_sin = h_sin + sl * hd
+ t_cos = cos + pid * hd
+ h_cos = t_cos + bs * sl * hd
+ w_cos = h_cos + bs * sl * hd
+ t_sin = sin + pid * hd
+ h_sin = t_sin + bs * sl * hd
+ w_sin = h_sin + bs * sl * hd

cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
Expand Down Expand Up @@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
cos,
sin,
seq_len,
+ batch_size,
n_q_head,
n_kv_head,
head_dim,
Expand Down Expand Up @@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
cos,
sin,
seq_len,
+ batch_size,
n_q_head,
n_kv_head,
head_dim,
Expand Down Expand Up @@ -216,8 +218,8 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (3, 1, seq_len, head_dim)
- sin size: (3, 1, seq_len, head_dim)
+ cos size: (3, bsz, seq_len, head_dim)
+ sin size: (3, bsz, seq_len, head_dim)
"""
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
ctx.save_for_backward(cos, sin)
Expand All @@ -228,10 +230,9 @@ def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
- cos size: (3, 1, seq_len, head_dim)
- sin size: (3, 1, seq_len, head_dim)
+ cos size: (3, bsz, seq_len, head_dim)
+ sin size: (3, bsz, seq_len, head_dim)
"""
-
cos, sin = ctx.saved_tensors
mrope_section = ctx.mrope_section
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
Expand Down
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Args:
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
- cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
- sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
+ cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
+ sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
Expand Down
8 changes: 6 additions & 2 deletions test/transformers/test_qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_correctness(
k2 = _tensor_k.clone().requires_grad_(True)

# NOTE: this position ids distribution is different from the real one, just to test op correctness
- pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
+ pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
+ 3, bsz, seq_len
+ )
cos, sin = rotary_emb(k1, pos_ids)

# validate forward pass
Expand Down Expand Up @@ -130,7 +132,9 @@ def test_functional_correctness(

rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)

- pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
+ pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
+ 3, bsz, seq_len
+ )
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)
Expand Down

0 comments on commit 2fd8f9d

Please sign in to comment.