
Commit 7c50ed1

Move sdpa to custom op until tensor slicing supported
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
1 parent d6704dd commit 7c50ed1

File tree

2 files changed (+60, -17 lines)


vllm/attention/ops/vit_attn_wrappers.py

Lines changed: 53 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import einops
 import torch
+import torch.nn.functional as F
 
 from vllm.utils.torch_utils import direct_register_custom_op
 
@@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
     )
+
+
+# TODO: Once we have a torch 2.10, we can use tensor slices
+# so we won't need to wrap this in custom ops
+def torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    outputs = []
+    for i in range(1, len(cu_seqlens)):
+        start_idx = cu_seqlens[i - 1]
+        end_idx = cu_seqlens[i]
+        q_i = q[:, start_idx:end_idx]
+        k_i = k[:, start_idx:end_idx]
+        v_i = v[:, start_idx:end_idx]
+        q_i, k_i, v_i = (
+            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+        )
+        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
+        outputs.append(output_i)
+    context_layer = torch.cat(outputs, dim=1)
+    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+    return context_layer
+
+
+def torch_sdpa_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    b, s, h, d = q.shape
+    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
+
+
+direct_register_custom_op(
+    op_name="torch_sdpa_wrapper",
+    op_func=torch_sdpa_wrapper,
+    fake_impl=torch_sdpa_wrapper_fake,
+)
+
+
+def vit_torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
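
Editor's note (not part of the diff): vit_torch_sdpa_wrapper dispatches to the registered torch.ops.vllm.torch_sdpa_wrapper op, whose eager implementation does the per-sequence slicing while torch_sdpa_wrapper_fake only reports the output shape. A minimal usage sketch on a packed batch follows; the tensor sizes and cu_seqlens values are illustrative assumptions, not taken from the commit.

# Illustrative usage sketch (not part of this commit); shapes and
# cu_seqlens values below are made-up assumptions.
import torch

from vllm.attention.ops.vit_attn_wrappers import vit_torch_sdpa_wrapper

b, s, h, d = 1, 10, 4, 32  # batch, packed sequence length, heads, head dim
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
v = torch.randn(b, s, h, d)
# Two sequences of lengths 6 and 4 packed along the sequence dimension.
cu_seqlens = torch.tensor([0, 6, 10], dtype=torch.int32)

out = vit_torch_sdpa_wrapper(q, k, v, cu_seqlens)
print(out.shape)  # expected: torch.Size([10, 1, 128]), i.e. (s, b, h * d)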

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 7 additions & 17 deletions
@@ -49,6 +49,7 @@
 )
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_torch_sdpa_wrapper,
     vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
@@ -428,23 +429,12 @@ def forward(
             )
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = einops.rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_torch_sdpa_wrapper(
+                q,
+                k,
+                v,
+                cu_seqlens,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
 
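
Editor's note (background, not part of the diff): the removed loop slices q, k, and v with runtime values taken from cu_seqlens; per the TODO in vit_attn_wrappers.py, this kind of tensor slicing is only expected to be traceable once torch 2.10 is available, so until then the loop is hidden behind a custom op. The graph produced for the support_torch_compile path then sees one opaque op whose output shape comes from the fake implementation. The standalone sketch below shows the same pattern using plain torch.library rather than vLLM's direct_register_custom_op helper; the op name demo::packed_sdpa and the simplified (b, s, h, d) output layout are assumptions for illustration only.

# Standalone sketch of the pattern this commit relies on, written against
# plain torch.library (PyTorch >= 2.4) instead of vLLM's helper.
# The op name and the simplified (b, s, h, d) output layout are assumptions.
import torch
import torch.nn.functional as F


@torch.library.custom_op("demo::packed_sdpa", mutates_args=())
def packed_sdpa(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor
) -> torch.Tensor:
    # Eager implementation: data-dependent slicing is fine here because
    # torch.compile treats the whole op as opaque.
    outputs = []
    for i in range(1, len(cu_seqlens)):
        s, e = cu_seqlens[i - 1], cu_seqlens[i]
        q_i, k_i, v_i = (x[:, s:e].transpose(1, 2) for x in (q, k, v))
        o = F.scaled_dot_product_attention(q_i, k_i, v_i)
        outputs.append(o.transpose(1, 2))
    return torch.cat(outputs, dim=1)


@packed_sdpa.register_fake
def _(q, k, v, cu_seqlens):
    # Shape-only implementation used during tracing/compilation, mirroring
    # the role of torch_sdpa_wrapper_fake in this commit.
    return torch.empty_like(q)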
