98 changes: 66 additions & 32 deletions tests/kernels/core/test_mrope.py
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple

import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size: int, max_position_embeddings: int,
dtype: torch.dtype, device: torch.device):
"""Generate test data for given configuration."""
current_platform.seed_everything(42)
# Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(0,
max_position_embeddings // 4, (3, num_tokens),
@@ -33,43 +37,67 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
return positions, query, key


def unroll_model_tp_dict(model_tp_dict):
return [(model_name, tp_size)
for model_name, tp_sizes in model_tp_dict.items()
for tp_size in tp_sizes]


model_tp_dict = {
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"Qwen/Qwen2-VL-72B-Instruct": [1, 2],
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2],
}

# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list = [
[torch.bfloat16, 1e-2, 1.6e-2],
class MRoPETestInfo(NamedTuple):
model_name: str
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
atol: float = 1e-2
rtol: float = 1.6e-2
marks: list[pytest.MarkDecorator] = []


TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
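# base_version drops pre-release/dev suffixes (e.g. "4.57.0.dev0" -> "4.57.0"),
# so the version gates below compare release versions cleanly.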

MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
]

num_tokens_list = [11, 8192]


@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_name, tp_size",
unroll_model_tp_dict(model_tp_dict))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):

atol = model_info.atol
rtol = model_info.rtol

config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()

# get the model config
total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
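# Some configs expose head_dim directly (it may differ from
# hidden_size // num_attention_heads); fall back to the quotient otherwise.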
is_neox_style = True

rope_theta = config.rope_theta
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):

@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
"model_name, tp_size",
unroll_model_tp_dict({
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2]
}))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", [4])
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
num_tokens):
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing(model_name: str,
model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):

atol = model_info.atol
rtol = model_info.rtol

config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()

# get the model config
total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
36 changes: 22 additions & 14 deletions vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -15,7 +15,7 @@


@triton.jit
def _triton_qwen2vl_mrope_forward(
def _triton_mrope_forward(
q_ptr,
k_ptr,
cos,
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flattened input tensors from vLLM
# and a cos/sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim)
# instead of (3, bsz, seq_len, head_dim); it also supports interleaved rotary
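# Rotary channel layout over the first half of head_dim (illustrative
# sketch; actual section sizes come from mrope_section):
#   sectioned:   [ t t ... t | h h ... h | w w ... w ]
#   interleaved: [ t h w t h w ... ], with any tail beyond the h/w
#   section limits falling back to t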
pid = tl.program_id(0)
# locate start address
q_ptr = q_ptr + pid * (n_qh * hd)
@@ -47,9 +49,6 @@
# ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)

t_end = mrope_section_t
h_end = t_end + mrope_section_h

# Updated stride calculation for half head_dim
half_rd = rd // 2
t_cos = cos + pid * half_rd
@@ -61,9 +60,18 @@

# Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) &
(cos_offsets <= 3 * mrope_section_h))
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
t_mask = ~(h_mask | w_mask)
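# Offsets not claimed by h/w (offsets % 3 == 0, plus any tail beyond
# the 3 * section_h / 3 * section_w limits) rotate with the temporal
# position.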
else:
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_mask = cos_offsets < mrope_section_t
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -131,6 +139,7 @@ def triton_mrope(
mrope_section: list[int],
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Qwen2VL mrope kernel.

@@ -158,7 +167,7 @@
cos = cos.contiguous()
sin = sin.contiguous()

_triton_qwen2vl_mrope_forward[(n_row, )](
_triton_mrope_forward[(n_row, )](
q,
k,
cos,
@@ -173,6 +182,8 @@
pad_hd,
mrope_section[0],
mrope_section[1],
mrope_section[2],
mrope_interleaved,
)
return q, k

@@ -201,7 +212,7 @@ def __init__(
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False,
mrope_interleaved: bool = False,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@@ -282,10 +293,6 @@ def forward_cuda(
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None

if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)

num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
@@ -302,6 +309,7 @@
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)

return q.reshape(query_shape), k.reshape(key_shape)
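
A minimal sketch of exercising the new Triton path end to end (assumptions: the `get_rope` keyword names and `rope_scaling` keys follow their use in the test file above, the `mrope_interleaved` key inside `rope_scaling` is an assumed spelling, and a CUDA/ROCm device is available; all dimensions are illustrative):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_size = 128  # illustrative dimensions only
num_q_heads, num_kv_heads, num_tokens = 8, 2, 4

rope = get_rope(
    head_size=head_size,
    rotary_dim=head_size,
    max_position=32768,
    base=1_000_000,
    is_neox_style=True,
    rope_scaling={
        "rope_type": "default",
        "mrope_section": [16, 24, 24],   # t/h/w split of head_size // 2
        "mrope_interleaved": True,       # select the interleaved layout
    },
    dtype=torch.bfloat16,
).to("cuda")

# (3, num_tokens) multimodal positions: one row each for t, h and w.
positions = torch.zeros(3, num_tokens, dtype=torch.long, device="cuda")
query = torch.randn(num_tokens, num_q_heads * head_size,
                    dtype=torch.bfloat16, device="cuda")
key = torch.randn(num_tokens, num_kv_heads * head_size,
                  dtype=torch.bfloat16, device="cuda")

# With this change, forward_cuda no longer falls back to forward_native
# for interleaved MRoPE; the two paths should agree within bf16 tolerance.
out_q, out_k = rope.forward_cuda(positions, query.clone(), key.clone())
ref_q, ref_k = rope.forward_native(positions, query.clone(), key.clone())
torch.testing.assert_close(out_q, ref_q, atol=1e-2, rtol=1.6e-2)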