From 2fd8f9db3d591faee9b6f10789b197ae5de52603 Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Tue, 10 Dec 2024 17:42:59 +0800 Subject: [PATCH] Fix Qwen2VL mrope for transformers 4.47.0 --- src/liger_kernel/ops/qwen2vl_mrope.py | 25 ++++++++++--------- .../transformers/qwen2vl_mrope.py | 4 +-- test/transformers/test_qwen2vl_mrope.py | 8 ++++-- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/liger_kernel/ops/qwen2vl_mrope.py b/src/liger_kernel/ops/qwen2vl_mrope.py index 8c2716281..103b15604 100644 --- a/src/liger_kernel/ops/qwen2vl_mrope.py +++ b/src/liger_kernel/ops/qwen2vl_mrope.py @@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope( cos, sin, sl, + bs: tl.constexpr, n_qh: tl.constexpr, n_kh: tl.constexpr, hd: tl.constexpr, @@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope( t_end = mrope_section_t h_end = t_end + mrope_section_h - cos_row_idx = pid % sl - t_cos = cos + cos_row_idx * hd - h_cos = t_cos + sl * hd - w_cos = h_cos + sl * hd - t_sin = sin + cos_row_idx * hd - h_sin = t_sin + sl * hd - w_sin = h_sin + sl * hd + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd cos_offsets = tl.arange(0, pad_hd // 2) t_mask = cos_offsets < t_end @@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): cos, sin, seq_len, + batch_size, n_q_head, n_kv_head, head_dim, @@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): cos, sin, seq_len, + batch_size, n_q_head, n_kv_head, head_dim, @@ -216,8 +218,8 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): """ q size: (bsz, n_q_head, seq_len, head_dim) k size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (3, 1, seq_len, head_dim) - sin size: (3, 1, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) """ q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) ctx.save_for_backward(cos, sin) @@ -228,10 +230,9 @@ def backward(ctx, dq, dk): """ dq size: (bsz, n_q_head, seq_len, head_dim) dk size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (3, 1, seq_len, head_dim) - sin size: (3, 1, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) """ - cos, sin = ctx.saved_tensors mrope_section = ctx.mrope_section dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) diff --git a/src/liger_kernel/transformers/qwen2vl_mrope.py b/src/liger_kernel/transformers/qwen2vl_mrope.py index f7b8cd6e8..c271837c4 100644 --- a/src/liger_kernel/transformers/qwen2vl_mrope.py +++ b/src/liger_kernel/transformers/qwen2vl_mrope.py @@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim Args: q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). - cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim). - sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim). mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation. unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py index bfc1f9ac2..239ba7784 100644 --- a/test/transformers/test_qwen2vl_mrope.py +++ b/test/transformers/test_qwen2vl_mrope.py @@ -73,7 +73,9 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) # NOTE: this position ids distribution is different from the real one, just to test op correctness - pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( + 3, bsz, seq_len + ) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -130,7 +132,9 @@ def test_functional_correctness( rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( + 3, bsz, seq_len + ) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)