From ec5e0bab733e0b6281c23f67e90d164798790695 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 17 Sep 2025 15:38:51 +0800 Subject: [PATCH 01/10] init Signed-off-by: Isotr0py --- .../layers/rotary_embedding/mrope.py | 116 +++++++++++++++++- 1 file changed, 110 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index ef61dbc1a5ab..3869e10f7b11 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -123,6 +123,110 @@ def _triton_qwen2vl_mrope_forward( tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) +@triton.jit +def _triton_interleaved_mrope_forward( + q_ptr, + k_ptr, + cos, + sin, + num_tokens, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + rd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, +): + pid = tl.program_id(0) + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) + + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + # Updated stride calculation for half head_dim + half_rd = rd // 2 + t_cos = cos + pid * half_rd + h_cos = t_cos + num_tokens * half_rd + w_cos = h_cos + num_tokens * half_rd + t_sin = sin + pid * half_rd + h_sin = t_sin + num_tokens * half_rd + w_sin = h_sin + num_tokens * half_rd + + # Updated offsets for half head_dim + cos_offsets = tl.arange(0, pad_hd // 2) * 2 + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) + + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( + 0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( + 0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( + 0, pad_hd // 2)[None, :] < rd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( + 0, pad_hd // 2)[None, :] < rd // 2) + + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, + mask=first_q_mask, + other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, + mask=first_k_mask, + other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (rd // 2) + second_half_k_offsets = first_half_k_offsets + (rd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, + mask=second_q_mask, + other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, + mask=second_k_mask, + other=0).to(sin_row.dtype) + + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + # Since cos and sin are now half-size, + # we use the same cos_row and sin_row for both halves + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + def triton_mrope( q: torch.Tensor, k: torch.Tensor, @@ -131,6 +235,7 @@ def triton_mrope( mrope_section: list[int], head_size: int, rotary_dim: int, + mrope_interleaved: bool, ) -> tuple[torch.Tensor, torch.Tensor]: """Qwen2VL mrope kernel. @@ -158,7 +263,9 @@ def triton_mrope( cos = cos.contiguous() sin = sin.contiguous() - _triton_qwen2vl_mrope_forward[(n_row, )]( + kernel = (_triton_interleaved_mrope_forward + if mrope_interleaved else _triton_qwen2vl_mrope_forward) + kernel[(n_row, )]( q, k, cos, @@ -201,7 +308,7 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, - mrope_interleaved: Optional[bool] = False, + mrope_interleaved: bool = False, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get @@ -282,10 +389,6 @@ def forward_cuda( assert positions.ndim == 1 or positions.ndim == 2 assert key is not None - if self.mrope_interleaved: - # TODO: add triton implementation to support mrope-interleaved - return self.forward_native(positions, query, key) - num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -302,6 +405,7 @@ def forward_cuda( self.mrope_section, self.head_size, self.rotary_dim, + self.mrope_interleaved, ) return q.reshape(query_shape), k.reshape(key_shape) From b04bf4edcde7e4191759d4f6236eafef169bf6f5 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 17 Sep 2025 15:44:42 +0800 Subject: [PATCH 02/10] init test Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 2 ++ vllm/model_executor/layers/rotary_embedding/mrope.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 3f2f330f6dc3..e1e964cba2ed 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -43,6 +43,7 @@ def unroll_model_tp_dict(model_tp_dict): "Qwen/Qwen2-VL-7B-Instruct": [1, 2], "Qwen/Qwen2-VL-72B-Instruct": [1, 2], "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], + "Qwen/Qwen3-VL-4B-Instruct": [1, 2], "zai-org/GLM-4.1V-9B-Thinking": [1, 2], } @@ -115,6 +116,7 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): "model_name, tp_size", unroll_model_tp_dict({ "Qwen/Qwen2-VL-7B-Instruct": [1, 2], + "Qwen/Qwen3-VL-4B-Instruct": [1, 2], "zai-org/GLM-4.1V-9B-Thinking": [1, 2] })) @pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 3869e10f7b11..4359b59a4eb3 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -146,8 +146,8 @@ def _triton_interleaved_mrope_forward( k_ptr = k_ptr + pid * (n_kh * hd) # #################################################################### - # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position - # m of this program instance + # get the cos(mθ_{i...d, step=2}) and sin(mθ_{i...d, step=2}) + # for token position m of this program instance # #################################################################### # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) From 4c1642e31d1eedbb037ffece9d42dba974ed9b99 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 17 Sep 2025 18:21:51 +0800 Subject: [PATCH 03/10] fix sin cos layout Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 10 +++++++--- .../model_executor/layers/rotary_embedding/mrope.py | 13 ++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index e1e964cba2ed..08b3193ede5b 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -49,7 +49,7 @@ def unroll_model_tp_dict(model_tp_dict): # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 dtype_atol_rtol_list = [ - [torch.bfloat16, 1e-2, 1.6e-2], + [torch.bfloat16, 1.3e-2, 1.6e-2], ] num_tokens_list = [11, 8192] @@ -64,13 +64,15 @@ def unroll_model_tp_dict(model_tp_dict): def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = (config.head_dim if hasattr(config, "head_dim") else + config.hidden_size // total_num_heads) is_neox_style = True rope_theta = config.rope_theta @@ -124,13 +126,15 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, num_tokens): config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = (config.head_dim if hasattr(config, "head_dim") else + config.hidden_size // total_num_heads) is_neox_style = True rope_theta = config.rope_theta max_position = config.max_position_embeddings diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 4359b59a4eb3..89d6f31bd42f 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -151,9 +151,6 @@ def _triton_interleaved_mrope_forward( # #################################################################### # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) - t_end = mrope_section_t - h_end = t_end + mrope_section_h - # Updated stride calculation for half head_dim half_rd = rd // 2 t_cos = cos + pid * half_rd @@ -164,10 +161,12 @@ def _triton_interleaved_mrope_forward( w_sin = h_sin + num_tokens * half_rd # Updated offsets for half head_dim - cos_offsets = tl.arange(0, pad_hd // 2) * 2 - t_mask = cos_offsets < t_end - h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) - w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) + # create interleaved mask + # [TTT...HHH...WWW] -> [THTHWHTHW...TT] + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = (cos_offsets % 3) == 0 + h_mask = (cos_offsets % 3) == 1 + w_mask = (cos_offsets % 3) == 2 t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) From a16d76bc49960e4d18bdc3f97c698f998d7a0af9 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 17 Sep 2025 23:13:19 +0800 Subject: [PATCH 04/10] include qwen3_vl_moe Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 08b3193ede5b..bb101b48f815 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -44,6 +44,7 @@ def unroll_model_tp_dict(model_tp_dict): "Qwen/Qwen2-VL-72B-Instruct": [1, 2], "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], "Qwen/Qwen3-VL-4B-Instruct": [1, 2], + "Qwen/Qwen3-VL-30B-A3B-Instruct": [1, 2], "zai-org/GLM-4.1V-9B-Thinking": [1, 2], } @@ -119,6 +120,7 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): unroll_model_tp_dict({ "Qwen/Qwen2-VL-7B-Instruct": [1, 2], "Qwen/Qwen3-VL-4B-Instruct": [1, 2], + "Qwen/Qwen3-VL-30B-A3B-Instruct": [1, 2], "zai-org/GLM-4.1V-9B-Thinking": [1, 2] })) @pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) From df4b50e5c1abf46c51162fe76b115cc9cfc7f9ad Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 00:14:24 +0800 Subject: [PATCH 05/10] update test Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 96 +++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index bb101b48f815..bc1c36d6405a 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple import pytest import torch +from packaging.version import Version from transformers import AutoConfig +from transformers import __version__ as TRANSFORMERS_VERSION from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, head_size: int, max_position_embeddings: int, dtype: torch.dtype, device: torch.device): """Generate test data for given configuration.""" + current_platform.seed_everything(42) # Create 2D positions (3, num_tokens) for multimodal case positions = torch.randint(0, max_position_embeddings // 4, (3, num_tokens), @@ -33,24 +37,39 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, return positions, query, key -def unroll_model_tp_dict(model_tp_dict): - return [(model_name, tp_size) - for model_name, tp_sizes in model_tp_dict.items() - for tp_size in tp_sizes] - - -model_tp_dict = { - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "Qwen/Qwen2-VL-72B-Instruct": [1, 2], - "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], - "Qwen/Qwen3-VL-4B-Instruct": [1, 2], - "Qwen/Qwen3-VL-30B-A3B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2], -} - -# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 -dtype_atol_rtol_list = [ - [torch.bfloat16, 1.3e-2, 1.6e-2], +class MRoPETestInfo(NamedTuple): + model_name: str + # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 + atol: float = 1e-2 + rtol: float = 1.6e-2 + marks: list[pytest.MarkDecorator] = [] + + +TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version + +MODELS_TO_TEST = [ + MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-4B-Instruct", + atol=1.3e-2, + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ]), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", + atol=2.5e-1, + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ]), ] num_tokens_list = [11, 8192] @@ -58,11 +77,18 @@ def unroll_model_tp_dict(model_tp_dict): @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize("model_name, tp_size", - unroll_model_tp_dict(model_tp_dict)) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) +@pytest.mark.parametrize("model_info, model_name", [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST +]) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("num_tokens", num_tokens_list) -def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): +def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, + dtype: torch.dtype, num_tokens: int): + + atol = model_info.atol + rtol = model_info.rtol config = AutoConfig.from_pretrained(model_name) config = config.get_text_config() @@ -115,18 +141,20 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize( - "model_name, tp_size", - unroll_model_tp_dict({ - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "Qwen/Qwen3-VL-4B-Instruct": [1, 2], - "Qwen/Qwen3-VL-30B-A3B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2] - })) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) -@pytest.mark.parametrize("num_tokens", [4]) -def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, - num_tokens): +@pytest.mark.parametrize("model_info, model_name", [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST +]) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_tokens", num_tokens_list) +def test_mrope_torch_compile_tracing(model_name: str, + model_info: MRoPETestInfo, tp_size: int, + dtype: torch.dtype, num_tokens: int): + + atol = model_info.atol + rtol = model_info.rtol + config = AutoConfig.from_pretrained(model_name) config = config.get_text_config() From 1130de2bc51c07701b7e1d64994c22a5273d6812 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 00:23:20 +0800 Subject: [PATCH 06/10] consolidate kernel Signed-off-by: Isotr0py --- .../layers/rotary_embedding/mrope.py | 124 ++---------------- 1 file changed, 13 insertions(+), 111 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 89d6f31bd42f..2955210db313 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -15,7 +15,7 @@ @triton.jit -def _triton_qwen2vl_mrope_forward( +def _triton_mrope_forward( q_ptr, k_ptr, cos, @@ -30,12 +30,13 @@ def _triton_qwen2vl_mrope_forward( pad_hd: tl.constexpr, mrope_section_t: tl.constexpr, mrope_section_h: tl.constexpr, + is_interleaved: tl.constexpr, ): # Adapted from # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # This version supports flatten input tensors from vllm # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2) - # instead of (3, bsz, seq_len, head_dim) + # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary pid = tl.program_id(0) # locate start address q_ptr = q_ptr + pid * (n_qh * hd) @@ -61,112 +62,14 @@ def _triton_qwen2vl_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) - t_mask = cos_offsets < t_end - h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) - w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) - - t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) - h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) - w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) - t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) - h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) - w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) - - cos_row = t_cos_row + h_cos_row + w_cos_row - sin_row = t_sin_row + h_sin_row + w_sin_row - - # #################################################################### - # Load the left and right half of q and k for the current - # program instance (i.e. for the current token) separately - # #################################################################### - # left half of the head - first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) - - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, - mask=first_q_mask, - other=0).to(sin_row.dtype) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, - mask=first_k_mask, - other=0).to(sin_row.dtype) - - # right half of the head - second_half_q_offsets = first_half_q_offsets + (rd // 2) - second_half_k_offsets = first_half_k_offsets + (rd // 2) - second_q_mask = first_q_mask - second_k_mask = first_k_mask - - q_tile_2 = tl.load(q_ptr + second_half_q_offsets, - mask=second_q_mask, - other=0).to(sin_row.dtype) - k_tile_2 = tl.load(k_ptr + second_half_k_offsets, - mask=second_k_mask, - other=0).to(sin_row.dtype) - - # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] - # Since cos and sin are now half-size, - # we use the same cos_row and sin_row for both halves - new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row - tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) - new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row - tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) - - new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row - tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) - new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row - tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) - - -@triton.jit -def _triton_interleaved_mrope_forward( - q_ptr, - k_ptr, - cos, - sin, - num_tokens, - n_qh: tl.constexpr, - n_kh: tl.constexpr, - hd: tl.constexpr, - rd: tl.constexpr, - pad_n_qh: tl.constexpr, - pad_n_kh: tl.constexpr, - pad_hd: tl.constexpr, - mrope_section_t: tl.constexpr, - mrope_section_h: tl.constexpr, -): - pid = tl.program_id(0) - # locate start address - q_ptr = q_ptr + pid * (n_qh * hd) - k_ptr = k_ptr + pid * (n_kh * hd) - - # #################################################################### - # get the cos(mθ_{i...d, step=2}) and sin(mθ_{i...d, step=2}) - # for token position m of this program instance - # #################################################################### - # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) - - # Updated stride calculation for half head_dim - half_rd = rd // 2 - t_cos = cos + pid * half_rd - h_cos = t_cos + num_tokens * half_rd - w_cos = h_cos + num_tokens * half_rd - t_sin = sin + pid * half_rd - h_sin = t_sin + num_tokens * half_rd - w_sin = h_sin + num_tokens * half_rd - - # Updated offsets for half head_dim - # create interleaved mask - # [TTT...HHH...WWW] -> [THTHWHTHW...TT] - cos_offsets = tl.arange(0, pad_hd // 2) - t_mask = (cos_offsets % 3) == 0 - h_mask = (cos_offsets % 3) == 1 - w_mask = (cos_offsets % 3) == 2 + if is_interleaved: + t_mask = (cos_offsets % 3) == 0 + h_mask = (cos_offsets % 3) == 1 + w_mask = (cos_offsets % 3) == 2 + else: + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) @@ -262,9 +165,7 @@ def triton_mrope( cos = cos.contiguous() sin = sin.contiguous() - kernel = (_triton_interleaved_mrope_forward - if mrope_interleaved else _triton_qwen2vl_mrope_forward) - kernel[(n_row, )]( + _triton_mrope_forward[(n_row, )]( q, k, cos, @@ -279,6 +180,7 @@ def triton_mrope( pad_hd, mrope_section[0], mrope_section[1], + mrope_interleaved, ) return q, k From 2d79d757eb2ccd370b617747461aec04091abe75 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 00:34:27 +0800 Subject: [PATCH 07/10] remove redundant computation for interleaved mrope Signed-off-by: Isotr0py --- vllm/model_executor/layers/rotary_embedding/mrope.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 2955210db313..4e567a19264a 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -48,9 +48,6 @@ def _triton_mrope_forward( # #################################################################### # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) - t_end = mrope_section_t - h_end = t_end + mrope_section_h - # Updated stride calculation for half head_dim half_rd = rd // 2 t_cos = cos + pid * half_rd @@ -67,7 +64,9 @@ def _triton_mrope_forward( h_mask = (cos_offsets % 3) == 1 w_mask = (cos_offsets % 3) == 2 else: - t_mask = cos_offsets < t_end + t_end = mrope_section_t + h_end = t_end + mrope_section_h + t_mask = cos_offsets < mrope_section_t h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) From f2926335025e22d09810244df9dab6712d392ea9 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 01:57:32 +0800 Subject: [PATCH 08/10] fix t Signed-off-by: Isotr0py --- vllm/model_executor/layers/rotary_embedding/mrope.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 4e567a19264a..2ccbe54bffc1 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -30,6 +30,7 @@ def _triton_mrope_forward( pad_hd: tl.constexpr, mrope_section_t: tl.constexpr, mrope_section_h: tl.constexpr, + mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, ): # Adapted from @@ -60,9 +61,9 @@ def _triton_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - t_mask = (cos_offsets % 3) == 0 - h_mask = (cos_offsets % 3) == 1 - w_mask = (cos_offsets % 3) == 2 + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t h_end = t_end + mrope_section_h @@ -179,6 +180,7 @@ def triton_mrope( pad_hd, mrope_section[0], mrope_section[1], + mrope_section[2], mrope_interleaved, ) return q, k From 4bd57ed02783beff85e752a1aeff0953e7595d7f Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 01:59:22 +0800 Subject: [PATCH 09/10] fix t Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index bc1c36d6405a..5a903438f5e9 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -54,7 +54,6 @@ class MRoPETestInfo(NamedTuple): MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), MRoPETestInfo( model_name="Qwen/Qwen3-VL-4B-Instruct", - atol=1.3e-2, marks=[ pytest.mark.skipif( Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), @@ -63,7 +62,6 @@ class MRoPETestInfo(NamedTuple): ]), MRoPETestInfo( model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", - atol=2.5e-1, marks=[ pytest.mark.skipif( Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), From 9fb001d43abf1c66375434b3d246b4c58c011b14 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Sep 2025 02:17:42 +0800 Subject: [PATCH 10/10] code format Signed-off-by: Isotr0py --- vllm/model_executor/layers/rotary_embedding/mrope.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 2ccbe54bffc1..ccc59bbbe233 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -61,8 +61,10 @@ def _triton_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) - w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + h_mask = (((cos_offsets % 3) == 1) & + (cos_offsets <= 3 * mrope_section_h)) + w_mask = (((cos_offsets % 3) == 2) & + (cos_offsets <= 3 * mrope_section_w)) t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t