diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py
index 040ddac10258..74d29e79d96c 100644
--- a/tests/kernels/test_flex_attention.py
+++ b/tests/kernels/test_flex_attention.py
@@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
         set_seed(seed)
 
@@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
         set_seed(seed)
         llm_default = LLM(
             model_name,
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 17b0f259cb76..dd8d7994ed33 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -13,7 +13,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -237,17 +236,13 @@ def final_mask_mod(
 
     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        # FIXME: With TP>1, create_block_mask_compiled will raise
-        # CUDA error: an illegal memory access was encountered
-        create_block_mask_fn = (create_block_mask_compiled
-                                if get_tensor_model_parallel_world_size() == 1
-                                else create_block_mask)
-        return create_block_mask_fn(
+        return create_block_mask_compiled(
             self.mask_mod,
             None,
             None,
             self.num_actual_tokens,
             self.total_cache_tokens,
+            device=self.block_table.device,
         )
 
     def __post_init__(self):
@@ -429,7 +424,6 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
-
         if output_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
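
The core change above is that `build_block_mask` now always uses `create_block_mask_compiled` and passes `device=self.block_table.device` explicitly, rather than falling back to the uncompiled path when TP > 1. Below is a minimal standalone sketch, not vLLM code, showing that PyTorch's `torch.nn.attention.flex_attention.create_block_mask` accepts an explicit `device`; building the mask on the worker's own GPU instead of the default `"cuda"` device is presumably what makes the compiled path safe under tensor parallelism. It assumes PyTorch >= 2.5 with FlexAttention and at least two CUDA devices; the `causal_mask` helper is made up for illustration.

```python
# Minimal sketch (not vLLM code). Assumes PyTorch >= 2.5 and >= 2 GPUs.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod: allow attention only to positions at or before the query index.
    return q_idx >= kv_idx


device = torch.device("cuda:1")  # a non-default GPU, as a non-rank-0 TP worker would use
q = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device=device, dtype=torch.float16)

# create_block_mask defaults to device="cuda" (i.e. cuda:0). Passing the
# tensors' own device keeps the BlockMask and the KV data on the same GPU,
# mirroring what the patch does via device=self.block_table.device.
block_mask = create_block_mask(causal_mask, None, None, 1024, 1024,
                               device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
```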