Commit b2155ed

[Model][Qwen3VL] Compute cu_seqlens on CPU to remove GPU sync (#26496)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 910abdb commit b2155ed

vllm/model_executor/models/qwen3_vl.py

Lines changed: 7 additions & 4 deletions
@@ -488,7 +488,9 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
 
         indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(dtype=self.dtype, device=self.device)
+        weights = weights.to(
+            dtype=self.dtype, device=self.device, non_blocking=True
+        )
 
         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
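
A note on the change above: non_blocking=True asks PyTorch to enqueue the host-to-device copy on the current CUDA stream and return immediately, rather than stalling the Python thread until the transfer finishes. The copy can only truly overlap with host work when the source tensor lives in pinned (page-locked) memory; otherwise it silently degrades to a synchronous transfer. A minimal sketch of the pattern, assuming a CUDA device (the tensor shapes and names are illustrative, not taken from the model):

    import torch

    # Illustrative only: pin the host tensor so the copy can run asynchronously.
    weights = torch.rand(4, 1024, 1).pin_memory()
    # Enqueue the host-to-device copy without blocking the caller.
    weights_gpu = weights.to(dtype=torch.float16, device="cuda", non_blocking=True)
    # Kernels launched on the same stream afterwards are ordered after the copy,
    # so weights_gpu can be used without an explicit synchronization.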
@@ -524,14 +526,15 @@ def forward(
         x: torch.Tensor,
         grid_thw: list[list[int]],
     ) -> torch.Tensor:
-        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
+        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
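
To make the cu_seqlens construction above concrete: each row of grid_thw is (t, h, w), and every one of the t frames contributes h * w patches, so repeat_interleave expands the per-frame patch counts. The cumulative-sum step falls between this hunk and the next, but the torch.cat with a leading zero below implies it. A worked example with made-up grid sizes, runnable entirely on CPU now that grid_thw_tensor is no longer created on the device:

    import torch

    # Illustrative values: two inputs with (t, h, w) grids of (2, 4, 4) and (1, 2, 3).
    grid_thw_tensor = torch.tensor([[2, 4, 4], [1, 2, 3]], dtype=torch.int32)
    seqlens = torch.repeat_interleave(
        grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
    )  # tensor([16, 16, 6], dtype=torch.int32)
    # A cumulative sum plus a leading zero gives FlashAttention-style
    # cumulative sequence lengths (see the torch.cat in the next hunk).
    cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
    # tensor([ 0, 16, 32, 38], dtype=torch.int32)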
@@ -542,8 +545,8 @@ def forward(
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
-        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
         max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):

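The payoff of the reordering in this last hunk: compute_attn_mask_seqlen has to produce a Python integer max_seqlen, and its body is not part of this diff, so the .max().item() below is an assumption about what such a helper typically does. Reading a scalar out of a GPU tensor blocks the host until the value arrives over a device-to-host sync, whereas reading it from a CPU tensor is free; only afterwards is the finished cu_seqlens shipped to the device asynchronously. A hedged sketch of the before/after pattern, assuming a CUDA device:

    import torch

    seqlens = torch.tensor([16, 16, 6], dtype=torch.int32)

    # Before: the tensor lives on the GPU, so extracting the scalar blocks
    # the host until the value is available on the CPU.
    max_seqlen = int(seqlens.to("cuda").max().item())  # device-to-host sync

    # After: the scalar is read on the CPU at no cost, and the tensor is
    # copied to the device without stalling the host.
    max_seqlen = int(seqlens.max().item())  # no GPU sync
    cu_seqlens_gpu = seqlens.to("cuda", non_blocking=True)
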