Commit

[V1] Add V1 support of Qwen2-VL (vllm-project#12128)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: imkero <kerorek@outlook.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
3 people authored and Isotr0py committed Feb 2, 2025
1 parent 5f21ee7 commit c3cacb8
Showing 9 changed files with 292 additions and 85 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
-   -
+   - ✅︎
* - `UltravoxModel`
- Ultravox
- T + A<sup>E+</sup>
18 changes: 8 additions & 10 deletions tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]

- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
- cur_batch_image_count]
- ])
+ cur_batch_image_count])

result.append({
"image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]

- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
- cur_batch_video_count]
- ])
+ cur_batch_video_count])

result.append({
"video_embeds":
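For reference, the embedding-length arithmetic the updated test relies on: each `grid_thw` row is `(t, h, w)` in patch units, and the spatial merge collapses `merge_size × merge_size` patches into one placeholder token. A minimal sketch with made-up grid values (a `merge_size` of 2 matches the usual Qwen2-VL image processor default, but treat the numbers as assumptions):

```python
import torch

# Hypothetical batch of two images: 1x28x28 and 1x56x28 patches.
image_grid_thw = torch.tensor([[1, 28, 28], [1, 56, 28]])
merge_size = 2  # assumed spatial merge factor

# Each row contributes t*h*w patches, merged down by merge_size**2.
cur_batch_embed_len = sum(
    grid_thw.prod(-1) // merge_size // merge_size
    for grid_thw in image_grid_thw)

print(cur_batch_embed_len)  # tensor(588) -> 196 + 392 placeholder tokens
```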
14 changes: 12 additions & 2 deletions vllm/compilation/decorators.py
@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- - if it is a single integer, the corresponding dimension of the argument
- will be marked as dynamic.
+ - if it is a single integer (can be negative), the corresponding dimension
+ of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [
arg.ndim + dim if dim < 0 else dim for dim in dims
]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [
tensor.ndim + dim if dim < 0 else dim
for dim in dims
]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
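The behavioral change above is that entries in `dynamic_arg_dims` may now be negative; they are rebased against each tensor's rank before `torch._dynamo.mark_dynamic` is called. A standalone sketch of that normalization, using assumed shapes:

```python
import torch

def resolve_dynamic_dims(tensor: torch.Tensor, dims) -> list:
    # Mirrors the decorator's handling: promote a bare int to a list,
    # then rebase negative indices against the tensor's rank.
    dims = [dims] if isinstance(dims, int) else dims
    return [tensor.ndim + d if d < 0 else d for d in dims]

positions_mrope = torch.zeros(3, 1024, dtype=torch.long)  # (3, seq_len)
positions_plain = torch.zeros(1024, dtype=torch.long)     # (seq_len,)

print(resolve_dynamic_dims(positions_mrope, -1))  # [1]
print(resolve_dynamic_dims(positions_plain, -1))  # [0]
```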
44 changes: 43 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -841,6 +841,37 @@ def get_input_positions(
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
input_tokens,
image_grid_thw,
video_grid_thw,
image_token_id,
video_token_id,
vision_start_token_id,
vision_end_token_id,
spatial_merge_size,
context_len,
seq_len,
)

return llm_positions.tolist(), mrope_position_delta

@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""

if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ get_input_positions(
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

- return llm_positions.tolist(), mrope_position_delta
+ return llm_positions, mrope_position_delta

@staticmethod
def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
seq_len + mrope_position_delta)) for _ in range(3)
]

@staticmethod
def get_next_input_positions_tensor(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> torch.Tensor:
return torch.arange(
mrope_position_delta + context_len,
mrope_position_delta + seq_len,
).expand(3, -1)


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

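The new tensor-returning helpers compute the same values as their list-based counterparts; they just keep the result as a `torch.Tensor` so the V1 model runner can avoid Python-list round trips. A small illustration of `get_next_input_positions_tensor` with made-up values:

```python
import torch

def get_next_input_positions_tensor(mrope_position_delta: int,
                                     context_len: int,
                                     seq_len: int) -> torch.Tensor:
    # One shared position range, broadcast to the three M-RoPE sections
    # (temporal, height, width).
    return torch.arange(mrope_position_delta + context_len,
                        mrope_position_delta + seq_len).expand(3, -1)

# Hypothetical decode step: delta of -4, positions 10..13 to be scheduled.
print(get_next_input_positions_tensor(-4, 10, 14))
# tensor([[6, 7, 8, 9],
#         [6, 7, 8, 9],
#         [6, 7, 8, 9]])
```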
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llava_onevision.py
@@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key == "pixel_values" and "images" not in modalities:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)

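The widened key check means precomputed `image_embeds`/`video_embeds` inputs now select a modality parser just like raw pixel values. A condensed, self-contained sketch of the routing (the real parser calls are stubbed out with strings here):

```python
def route_modalities(**kwargs) -> dict:
    # Preserve the caller's kwarg order; the first matching key per
    # modality wins, whether it is pixel data or precomputed embeddings.
    modalities = {}
    for input_key in kwargs:
        if input_key in ("pixel_values", "image_embeds") \
                and "images" not in modalities:
            modalities["images"] = "image_input"
        if input_key in ("pixel_values_videos", "video_embeds") \
                and "videos" not in modalities:
            modalities["videos"] = "video_input"
    return modalities

print(route_modalities(image_embeds=[...], pixel_values_videos=[...]))
# {'images': 'image_input', 'videos': 'video_input'}
```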
10 changes: 9 additions & 1 deletion vllm/model_executor/models/qwen2.py
@@ -256,7 +256,7 @@ def forward(
return hidden_states, residual


- @support_torch_compile
+ @support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ })
class Qwen2Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
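Passing `-1` for `positions` works for both layouts because the sequence axis is the last dimension whether or not M-RoPE is active; with M-RoPE the tensor simply gains a leading size-3 axis for its temporal/height/width sections. A toy illustration with assumed shapes:

```python
import torch

seq_len = 512
positions_mrope = torch.zeros(3, seq_len, dtype=torch.long)  # Qwen2-VL (M-RoPE)
positions_plain = torch.zeros(seq_len, dtype=torch.long)     # standard RoPE

# In both cases dim -1 is the sequence axis, so marking it dynamic
# covers variable-length prompts for either layout.
assert positions_mrope.shape[-1] == seq_len
assert positions_plain.shape[-1] == seq_len
```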
