2 changes: 0 additions & 2 deletions tests/kernels/test_flex_attention.py
@@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        set_seed(seed)

@@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
    # Run with default backend
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
        set_seed(seed)
        llm_default = LLM(
            model_name,
10 changes: 2 additions & 8 deletions vllm/v1/attention/backends/flex_attention.py
@@ -13,7 +13,6 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
-from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -237,17 +236,13 @@ def final_mask_mod(

    def build_block_mask(self) -> BlockMask:
        assert self.mask_mod is not None
-        # FIXME: With TP>1, create_block_mask_compiled will raise
-        # CUDA error: an illegal memory access was encountered
-        create_block_mask_fn = (create_block_mask_compiled
-                                if get_tensor_model_parallel_world_size() == 1
-                                else create_block_mask)
-        return create_block_mask_fn(
+        return create_block_mask_compiled(
            self.mask_mod,
            None,
            None,
            self.num_actual_tokens,
            self.total_cache_tokens,
+            device=self.block_table.device,
Collaborator:
So passing in device is the key?

Contributor Author:
Yeah, we have run into this before: torch.compile's lack of implicit device transfer causes it to show up as an IMA (illegal memory access).

I added some debug logging locally:

(VllmWorker rank=1 pid=4120217) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:1
(VllmWorker rank=6 pid=4120237) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:6
(VllmWorker rank=7 pid=4120242) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:7
(VllmWorker rank=3 pid=4120224) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:3
(VllmWorker rank=4 pid=4120227) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:4
(VllmWorker rank=0 pid=4120216) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:0
(VllmWorker rank=5 pid=4120233) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:5
(VllmWorker rank=2 pid=4120220) INFO 06-16 18:59:19 [flex_attention.py:240] create_block_mask_compiled called with device: cuda:2
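
A minimal standalone sketch of the pattern behind the fix (hypothetical mask_mod and token counts, not the vLLM code itself): each worker builds the block mask with an explicit device argument, so the compiled create_block_mask never has to rely on an implicit device transfer.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Assumption: the compiled variant is plain torch.compile over the eager helper.
create_block_mask_compiled = torch.compile(create_block_mask)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Toy causal mask; stands in for the backend's real mask_mod.
    return q_idx >= kv_idx

# Each tensor-parallel worker uses its own device, e.g. cuda:3 on rank 3.
device = torch.device("cuda", torch.cuda.current_device())

block_mask = create_block_mask_compiled(
    causal_mask_mod,
    B=None,         # broadcast over batch
    H=None,         # broadcast over heads
    Q_LEN=1024,     # hypothetical query token count
    KV_LEN=1024,    # hypothetical total cache token count
    device=device,  # explicit device: torch.compile will not insert the
                    # transfer for you, and omitting it can surface as an IMA
)

The diff here does the same thing by passing device=self.block_table.device, which is presumably already on the worker's GPU.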

        )

    def __post_init__(self):
@@ -429,7 +424,6 @@ def forward(
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
-
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"