
Conversation

@Isotr0py (Member) commented Sep 2, 2025

Purpose

Test Plan

pytest -s -v tests/v1/attention/test_attention_backends.py

Test Result

Test should still pass.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
mergify bot added the v1 label Sep 2, 2025
Isotr0py marked this pull request as ready for review September 4, 2025 07:06
Isotr0py requested a review from zou3519 September 4, 2025 14:35
@Isotr0py (Member, Author) commented Sep 6, 2025

Also cc @drisspg for visibility


def build_block_mask(self) -> BlockMask:
    if self.causal:
        if self.sliding_window is not None:
Contributor (review comment on the snippet above):

I think this would still work with the direct build path. Is your new test checking this?

Isotr0py (Member, Author) replied:

Fixed in 181f15d. The new test can cover the direct build code path; it's currently disabled for torch 2.8:

    use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
    if backend == "FLEX_ATTENTION_SLOW":
        actual_backend = _Backend.FLEX_ATTENTION
        use_direct_block_mask = False

I modified it to the following locally and confirmed it passed on torch 2.8:

    use_direct_block_mask = True
    if backend == "FLEX_ATTENTION_SLOW":
        actual_backend = _Backend.FLEX_ATTENTION
        use_direct_block_mask = False

But I didn't push that change in the commit, just in case something still needs the direct build to be disabled.

@noooop (Collaborator) commented Sep 8, 2025

Please test Alibaba-NLP/gte-reranker-modernbert-base and google/embeddinggemma-300m (need to manually set dtype = float32) to ensure the results of bi-directional attention + sliding window + Flex Attention are correct

pytest tests/models/language/pooling/test_st_projector.py::test_embed_models_mteb[model_info1] #24318
pytest tests/models/language/pooling/test_gte.py::test_rerank_models_mteb[model_info0] #22637

@Isotr0py (Member, Author) commented Sep 8, 2025

Please test Alibaba-NLP/gte-reranker-modernbert-base and google/embeddinggemma-300m (need to manually set dtype = float32) to ensure the results of bi-directional attention + sliding window + Flex Attention are correct

@noooop I have confirmed that both tests pass with fp32 locally:

Model: Alibaba-NLP/gte-reranker-modernbert-base
VLLM: torch.float32 0.33447
SentenceTransformers: Constant 0.33386
Difference: -0.0006099999999999994
PASSED
Model: google/embeddinggemma-300m
VLLM: torch.float32 0.7473782217631502
SentenceTransformers: Constant 0.7473819294684156
Difference: 3.7077052654765907e-06
PASSED

@heheda12345 (Collaborator) commented:

But as _build_block_mask_direct didn't implement sliding window, the block mask is still that of full attention. This implementation can produce correct output, but it will be slow because it doesn't skip the computation outside the sliding window. Is my understanding correct? (If so, I think we need a warning.)

              attn_metadata.block_mask = (
                    attn_metadata._build_block_mask_direct())

@Isotr0py (Member, Author) commented Sep 15, 2025

But as _build_block_mask_direct didn't implement sliding window, the block mask is still that of full attention.

Unlike build_block_mask, _build_block_mask_direct doesn't prepare the mask_mod inside itself; it always uses attn_metadata.mask_mod to create the block mask:

block_mask_kwargs = {
    "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
    "kv_num_blocks": kv_num_blocks[None, None],
    "kv_indices": kv_indices[None, None],
    "full_kv_num_blocks": None,
    "full_kv_indices": None,
    "BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
    "mask_mod": self.mask_mod,
}

In fact, the key is attn_metadata.mask_mod = attn_metadata.get_mask_mod(), which is called before _build_block_mask_direct (though this implementation is a little bit hacky):

                # update mask mod in attention metadata
                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
                attn_metadata.block_mask = (
                    attn_metadata._build_block_mask_direct())

If attn_metadata.sliding_window is set to a valid value (e.g. at the switching stage from full attn -> sliding window), attn_metadata.get_mask_mod() will create a sliding-window mask_mod, and _build_block_mask_direct() will use it to create a sliding-window block mask.

At the switching stage from sliding window -> full attn, attn_metadata.sliding_window will be reset to None, and attn_metadata.mask_mod = attn_metadata.get_mask_mod() will rebuild a full-attention mask_mod for _build_block_mask_direct().
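
For reference, a minimal standalone sketch of what such sliding-window mask_mods look like with the public FlexAttention API; the helper names below are illustrative and are not vLLM's actual get_mask_mod:

# Minimal sketch using the public FlexAttention API (torch >= 2.5).
# Helper names here are hypothetical, not vLLM's actual implementation.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal_sliding_window_mask(window: int):
    # mask_mod contract: (batch, head, q_idx, kv_idx) -> bool tensor
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        in_window = q_idx - kv_idx < window
        return causal & in_window

    return mask_mod


def bidirectional_sliding_window_mask(window: int):
    # Encoder-style local attention (e.g. interleaved sliding window in
    # ModernBERT-like models): no causal constraint, just a distance cutoff.
    def mask_mod(b, h, q_idx, kv_idx):
        return torch.abs(q_idx - kv_idx) < window

    return mask_mod


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    block_mask = create_block_mask(
        causal_sliding_window_mask(window=128),
        B=None, H=None, Q_LEN=1024, KV_LEN=1024, device=device,
    )
    # Blocks entirely outside the window are pruned, so flex_attention skips them.
    print(block_mask.sparsity())

create_block_mask prunes the blocks that the mask_mod rejects entirely, which is what lets FlexAttention skip out-of-window computation.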

@drisspg (Contributor) left a comment:

I think this is fine as is, but can you create some issues for the follow-up work? The dynamic creation of the block_mask is not ideal.

@heheda12345 (Collaborator) commented:

Did you try some e2e models? I tried google/gemma-3-4b-it, google/gemma-3-1b-it, and openai/gpt-oss-20b on H100 and got different errors:

VLLM_ATTENTION_BACKEND=FLEX_ATTENTION vllm bench throughput --model MODEL --input-len 1024 --output-len 1024

The logs are attached
gemma-3-1b.log
gemma-3-4b.log
gpt-oss.log

@Isotr0py (Member, Author) commented:

Did you try some e2e model?

I tried microsoft/Phi-tiny-MoE-instruct (model-level sliding window) and Alibaba-NLP/gte-reranker-modernbert-base (encoder-only interleaved sliding window); both can generate outputs normally. It seems this issue only occurs on decoder-only interleaved sliding window models.

I just ran google/gemma-3-1b-it, but got another error different from the above logs, raised at flex_attention_compiled:

VLLM_ATTENTION_BACKEND=FLEX_ATTENTION python examples/offline_inference/basic/generate.py --model google/gemma-3-270m-it --enforce-eager
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]     self.impl.forward(self,
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]   File "/home/mozf/develop-projects/vllm/vllm/v1/attention/backends/flex_attention.py", line 805, in forward
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]     out = flex_attention_compiled(
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]           ^^^^^^^^^^^^^^^^^^^^^^^^
...
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]   File "/home/mozf/develop-projects/vllm/.venv/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py", line 312, in describe_tensor
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]     storage = self.describe_storage(t.untyped_storage(), trace=trace)
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720]                                     ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3629297) ERROR 09-16 23:53:40 [core.py:720] torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/mozf/develop-projects/vllm/.venv/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 903, in create_block_mask

Full logs:
gemma3_logs.txt

I suspect the dynamic block_mask creation causes this issue when using the hybrid allocator; let me investigate.

@noooop (Collaborator) commented Sep 16, 2025

I just ran google/gemma-3-1b-it, but got another error different from above logs, which was raised at flex_attention_compiled:

+1 PTAL #24872 (comment)

(I only saw the keyword ‘compile’, maybe it’s not related.)

@Isotr0py (Member, Author) commented:

Oh, it seems we have to disable the hybrid allocator when using FlexAttention 😢:

VLLM_ATTENTION_BACKEND=FLEX_ATTENTION python examples/offline_inference/basic/generate.py --model google/gemma-3-270m-it --enforce-eager --disable-hybrid-kv-cache-manager
# VLLM_ATTENTION_BACKEND=FLEX_ATTENTION python examples/offline_inference/basic/generate.py --model google/gemma-3-270m-it --disable-hybrid-kv-cache-manager
INFO 09-16 16:25:27 [__init__.py:216] Automatically detected platform cuda.
INFO 09-16 16:25:29 [utils.py:328] non-default args: {'num_redundant_experts': None, 'eplb_window_size': None, 'eplb_step_interval': None, 'eplb_log_balancedness': None, 'enable_lora': None, 'disable_hybrid_kv_cache_manager': True, 'model': 'google/gemma-3-270m-it'}
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.35k/1.35k [00:00<00:00, 8.93MB/s]
INFO 09-16 16:25:37 [__init__.py:741] Resolved architecture: Gemma3ForCausalLM
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING 09-16 16:25:37 [__init__.py:2825] Your device 'Tesla T4' (with compute capability 7.5) doesn't support torch.bfloat16. Falling back to torch.float32 for compatibility.
INFO 09-16 16:25:37 [__init__.py:2870] Upcasting torch.bfloat16 to torch.float32.
INFO 09-16 16:25:37 [__init__.py:1814] Using max model len 32768
INFO 09-16 16:25:37 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 09-16 16:25:37 [__init__.py:3465] Turing devices tensor cores do not support float32 matmul. To workaround this limitation, vLLM will set 'ieee' input precision for chunked prefill triton kernels.
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.16M/1.16M [00:00<00:00, 17.3MB/s]
tokenizer.model: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.69M/4.69M [00:00<00:00, 49.6MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33.4M/33.4M [00:00<00:00, 170MB/s]
added_tokens.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35.0/35.0 [00:00<00:00, 471kB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 662/662 [00:00<00:00, 9.61MB/s]
chat_template.jinja: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.53k/1.53k [00:00<00:00, 21.1MB/s]
generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 173/173 [00:00<00:00, 1.40MB/s]
(EngineCore_DP0 pid=3851) INFO 09-16 16:25:41 [core.py:654] Waiting for init message from front-end.
(EngineCore_DP0 pid=3851) INFO 09-16 16:25:41 [core.py:76] Initializing a V1 LLM engine (v0.10.2rc3.dev114+g08369289a) with config: model='google/gemma-3-270m-it', speculative_config=None, tokenizer='google/gemma-3-270m-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float32, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=google/gemma-3-270m-it, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer","vllm.gdn_attention"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":1,"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":512,"local_cache_dir":null}
(EngineCore_DP0 pid=3851) ERROR 09-16 16:25:43 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
[W916 16:25:53.602992035 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W916 16:26:03.613730436 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W916 16:26:03.614490178 ProcessGroupNCCL.cpp:981] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:03 [parallel_state.py:1165] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=3851) WARNING 09-16 16:26:03 [topk_topp_sampler.py:69] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:03 [gpu_model_runner.py:2338] Starting to load model google/gemma-3-270m-it...
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:04 [gpu_model_runner.py:2370] Loading model from scratch...
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:04 [cuda.py:307] Using FlexAttention backend on V1 engine.
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:04 [weight_utils.py:348] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 536M/536M [00:02<00:00, 223MB/s]
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:07 [weight_utils.py:369] Time spent downloading weights for google/gemma-3-270m-it: 2.675552 seconds
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:07 [weight_utils.py:406] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.54it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.54it/s]
(EngineCore_DP0 pid=3851) 
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:08 [default_loader.py:268] Loading weights took 0.67 seconds
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:08 [gpu_model_runner.py:2392] Model loading took 1.0662 GiB and 3.933256 seconds
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:14 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/be4e8d6d30/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:14 [backends.py:550] Dynamo bytecode transform time: 5.36 s
(EngineCore_DP0 pid=3851) [rank0]:W0916 16:26:15.717000 3851 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:19 [backends.py:194] Cache the graph for dynamic shape for later use
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:37 [backends.py:215] Compiling a graph for dynamic shape takes 22.74 s
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:39 [monitor.py:34] torch.compile takes 28.11 s in total
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:40 [gpu_worker.py:276] Available KV cache memory: 9.91 GiB
(EngineCore_DP0 pid=3851) WARNING 09-16 16:26:40 [kv_cache_utils.py:1054] Hybrid KV cache manager is disabled for this hybrid model, This means we do not enable any optimizations for saving KV cache memory (e.g., dropping the KV cache outside the sliding window). The compute of layers like sliding window is still saved.
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:40 [kv_cache_utils.py:864] GPU KV cache size: 288,720 tokens
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:40 [kv_cache_utils.py:868] Maximum concurrency for 32,768 tokens per request: 8.81x
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████| 67/67 [00:02<00:00, 29.67it/s]
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:43 [gpu_model_runner.py:3118] Graph capturing finished in 3 secs, took 0.33 GiB
(EngineCore_DP0 pid=3851) INFO 09-16 16:26:43 [core.py:218] init engine (profile, create kv cache, warmup model) took 35.20 seconds
INFO 09-16 16:26:46 [llm.py:285] Supported_tasks: ['generate']
INFO 09-16 16:26:46 [__init__.py:36] No IOProcessor plugins requested by the model
WARNING 09-16 16:26:46 [__init__.py:1694] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1202.32it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.39s/it, est. speed input: 1.92 toks/s, output: 4.71 toks/s]
--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: " [Your Name] and I'm a [Your Profession/Title]. I"
--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: ' expected to meet with the nation’s top diplomats and leaders to discuss upcoming issues'
--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: ' Paris. The capital of France is the seat of the First State of France.'
--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' being defined by the ongoing and rapidly evolving development of deep learning. Deep learning,'

@noooop (Collaborator) commented Sep 16, 2025

By the way, I added ppl tests in #24485 for verifying model implementations, and added ci_envs in #24630 for convenient local testing.

e.g.

VLLM_CI_DTYPE=float32 pytest tests/models/language/pooling_mteb_test/test_st_projector.py::test_embed_models_mteb[model_info1] 
VLLM_CI_DTYPE=float32 pytest tests/models/language/pooling_mteb_test/test_gte.py::test_rerank_models_mteb[model_info0]

Feel free to use them and fix whatever you need!

@drisspg (Contributor) commented Sep 17, 2025

But as _build_block_mask_direct didn't implement sliding window, the block mask is still that of full attention. This implementation can produce correct output, but it will be slow because it doesn't skip the computation outside the sliding window. Is my understanding correct? (If so, I think we need a warning.)

              attn_metadata.block_mask = (
                    attn_metadata._build_block_mask_direct())

The direct build path should skip non-intra-window blocks if the page table correctly evicts those blocks.
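
To make this concrete, a small sketch under assumed shapes and values (this is not vLLM's actual _build_block_mask_direct): if the page table only retains the KV blocks inside the window, the BlockMask built directly from those indices is already sparse, and a mask_mod is only needed to refine partially-masked blocks at the window edge.

# Sketch of the idea with made-up shapes; NOT vLLM's direct-build code.
# Requires torch >= 2.5 for BlockMask.from_kv_blocks.
import torch
from torch.nn.attention.flex_attention import BlockMask

device = "cuda" if torch.cuda.is_available() else "cpu"
BLOCK = 128

# Per query block: how many KV blocks survived eviction, and which ones.
# Here each query block keeps at most itself and its predecessor,
# i.e. roughly a two-block sliding window.
kv_num_blocks = torch.tensor([1, 2, 2, 2], dtype=torch.int32, device=device)
kv_indices = torch.tensor(
    [[0, 0, 0, 0],
     [0, 1, 0, 0],
     [1, 2, 0, 0],
     [2, 3, 0, 0]],
    dtype=torch.int32, device=device,
)

block_mask = BlockMask.from_kv_blocks(
    kv_num_blocks[None, None],  # [B=1, H=1, num_q_blocks]
    kv_indices[None, None],     # [B=1, H=1, num_q_blocks, max_kv_blocks]
    BLOCK_SIZE=BLOCK,
    # A mask_mod would still be needed to refine partially-masked blocks at
    # the window edge; it is omitted here for brevity.
)
# Out-of-window blocks never appear in kv_indices, so they are skipped.
print(block_mask.sparsity())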

@heheda12345 (Collaborator) left a comment:

LGTM!
Tracking some follow-ups:

  1. https://github.com/vllm-project/vllm/pull/24089/files#r2341783626
  2. support hybrid allocator
  3. support real sliding window when using direct build

heheda12345 enabled auto-merge (squash) September 21, 2025 03:48
heheda12345 merged commit cf56cf7 into vllm-project:main Sep 21, 2025
42 of 43 checks passed
Isotr0py deleted the flex-sliding-window branch September 21, 2025 05:09
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

ready, v1


Development

Successfully merging this pull request may close these issues.

[Bug]: FlexAttention does not support sliding window yet.

7 participants