test_chunked_prefill_pooler referencing #23436 #24114
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add the `ready` label to the PR. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀

@maxdebayser plz review
Code Review
This pull request adds tests for chunked prefill and prefix caching with pooling models. The tests are a good addition, but they have a critical flaw in how they are structured. The tests create an instance of a DummyPooler and then assert on its state, but this instance is separate from the one created and used by the LLMEngine. This means the assertions will always be on an empty list of chunks, and the tests are not actually verifying the intended behavior. I've provided a suggestion to refactor the tests to correctly capture and assert on the state of the pooler used by the engine. This involves defining the DummyPooler class within each test function to leverage closures for state tracking, which is a more robust and self-contained testing pattern.
tests/test_chunked_prefill_pooler.py
Outdated
```python
class DummyPooler(LastPool):
    def __init__(self):
        super().__init__()
        self.chunks = []

    def __call__(self, hidden_states, pooling_cursor):
        self.chunks.append(hidden_states)
        return super().__call__(hidden_states, pooling_cursor)


def test_chunked_prefill_pooler(monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""
    model_id = "BAAI/bge-multilingual-gemma2"
    config = ModelConfig(model_id)
    pooler = DummyPooler()
    config.pooler_config = PoolerConfig(pooling_type="LAST")
    # Patch LLMEngine to use DummyPooler
    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
    engine = LLMEngine(config)
    prompt = "This is a test prompt for chunked prefill."
    output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
    # Check that chunks were received
    assert len(pooler.chunks) > 1
    # Compare with non-chunked output
    output_non_chunked = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=False)
    assert output[0] == output_non_chunked[0]


def test_chunked_prefill_prefix_caching(monkeypatch):
    """Test chunked prefill with prefix caching for pooling models."""
    model_id = "BAAI/bge-multilingual-gemma2"
    config = ModelConfig(model_id)
    pooler = DummyPooler()
    config.pooler_config = PoolerConfig(pooling_type="LAST")
    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
    engine = LLMEngine(config)
    prefix = "Common prefix. "
    prompt1 = prefix + "First input."
    prompt2 = prefix + "Second input."
    engine.generate([prompt1], max_tokens=8, enable_chunked_prefill=True)
    output2 = engine.generate([prompt2], max_tokens=8, enable_chunked_prefill=True)
    # The pooler should see hidden states of length (total - prefix length)
    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in pooler.chunks)
```
The current implementation of the tests has a critical flaw. The DummyPooler instance created in each test function (e.g., pooler = DummyPooler()) is not the same instance that the LLMEngine uses internally. When LLMEngine is initialized, it creates its own new instance of DummyPooler because of the monkeypatch. As a result, the assertions on pooler.chunks are always checking an empty list, and the tests are not actually verifying the intended behavior.
To fix this, the DummyPooler class should be defined inside each test function. This allows it to capture a chunks list from the test function's scope via a closure. This is a more robust and self-contained way to test this kind of behavior, as it avoids shared state between tests and correctly captures the data from the instance used by the engine.
Additionally, the state of the captured chunks should be managed carefully between multiple engine.generate() calls within a single test to ensure assertions are made on the correct set of chunks.
```python
import pytest
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import LastPool
from vllm.engine.llm_engine import LLMEngine


def test_chunked_prefill_pooler(monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""
    model_id = "BAAI/bge-multilingual-gemma2"
    config = ModelConfig(model_id)
    config.pooler_config = PoolerConfig(pooling_type="LAST")

    # This list will be populated by the DummyPooler instance inside the engine.
    chunks = []

    class DummyPooler(LastPool):
        def __call__(self, hidden_states, pooling_cursor):
            chunks.append(hidden_states)
            return super().__call__(hidden_states, pooling_cursor)

    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)

    # Test with chunked prefill
    engine = LLMEngine(config)
    prompt = "This is a test prompt for chunked prefill."
    output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
    # Check that chunks were received
    assert len(chunks) > 1

    # Compare with non-chunked output
    # A new engine is created to ensure a clean state for the pooler.
    engine_non_chunked = LLMEngine(config)
    output_non_chunked = engine_non_chunked.generate(
        [prompt], max_tokens=8, enable_chunked_prefill=False
    )
    assert output[0] == output_non_chunked[0]


def test_chunked_prefill_prefix_caching(monkeypatch):
    """Test chunked prefill with prefix caching for pooling models."""
    model_id = "BAAI/bge-multilingual-gemma2"
    config = ModelConfig(model_id)
    config.pooler_config = PoolerConfig(pooling_type="LAST")

    # This list will be populated by the DummyPooler instance inside the engine.
    chunks = []

    class DummyPooler(LastPool):
        def __call__(self, hidden_states, pooling_cursor):
            chunks.append(hidden_states)
            return super().__call__(hidden_states, pooling_cursor)

    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
    engine = LLMEngine(config)
    prefix = "Common prefix. "
    prompt1 = prefix + "First input."
    prompt2 = prefix + "Second input."

    # First call to populate the prefix cache
    engine.generate([prompt1], max_tokens=8, enable_chunked_prefill=True)
    # Clear chunks from the first call before the second call
    chunks.clear()
    # Second call should use the prefix cache
    engine.generate([prompt2], max_tokens=8, enable_chunked_prefill=True)
    # The pooler should see hidden states of length (total - prefix length)
    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in chunks)
```
I think this gemini review has a good point. Since the model runner runs in a separate process, the monkey patch in the pytest process should have no effect.
tests/test_chunked_prefill_pooler.py
Outdated
```python
import pytest
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import PoolingType, LastPool
from vllm.engine.llm_engine import LLMEngine
```
vllm.engine.llm_engine.LLMEngine is the old V0 engine class and we've just removed all the pooling code from V0. You can import vllm.v1.engine.llm_engine.LLMEngine instead. But I think you can also use the LLM entrypoint class. In this way you don't need to worry about building the pooling request correctly from the prompt.
tests/test_chunked_prefill_pooler.py
Outdated
```python
monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
engine = LLMEngine(config)
prompt = "This is a test prompt for chunked prefill."
output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
```
I think you mean LLM.generate(), because the LLMEngine class doesn't have a generate method. Also, generate is for token generation, not for pooling. LLM.embed() is a good method to use (LLM.encode() is more low level).
You could also use LLMEngine.add_request, but then you have to build the pooling request yourself, so LLM.embed() is easier.
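For concreteness, a minimal sketch of that suggestion; the argument names (runner, override_pooler_config) are taken from elsewhere in this thread and the .outputs.embedding field is an assumption, so adjust to the current API as needed:

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Build the engine through the high-level entrypoint so the pooling request
# is constructed internally from the prompt.
llm = LLM(
    model="BAAI/bge-multilingual-gemma2",
    runner="pooling",
    override_pooler_config=PoolerConfig(pooling_type="LAST"),
)

# LLM.embed() runs one pooling request per prompt and returns one result per
# prompt; the embedding itself is a plain list of floats.
outputs = llm.embed(["This is a test prompt for chunked prefill."])
embedding = outputs[0].outputs.embedding
```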
tests/test_chunked_prefill_pooler.py
Outdated
```python
monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
engine = LLMEngine(config)
prompt = "This is a test prompt for chunked prefill."
output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
```
enable_chunked_prefill is actually part of the initialization parameters; it's not set on a per-request basis. You might also have to change long_prefill_token_threshold to force chunking with small prompt sizes.
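A hedged sketch of what that engine-level configuration could look like; the tiny values are only there to force chunking on short prompts and are not meant as real defaults:

```python
from vllm import LLM

# Chunked prefill is an engine construction argument, not a per-request one.
llm = LLM(
    model="BAAI/bge-multilingual-gemma2",
    runner="pooling",
    enable_chunked_prefill=True,
    # Keep the per-step token budget tiny so even a short prompt is split
    # into several prefill chunks.
    max_num_batched_tokens=16,
    # Force prompts longer than this many tokens to be chunked.
    long_prefill_token_threshold=8,
)

outputs = llm.embed(["This prompt should now be prefilled in several chunks."])
```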
```python
monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
engine = LLMEngine(config)
prefix = "Common prefix. "
prompt1 = prefix + "First input."
```
The prefix has to be bigger than a paged attention block, otherwise it's not cached.
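A hedged sketch of that constraint, assuming the default block_size of 16 tokens (passed explicitly here) and the enable_prefix_caching engine argument:

```python
from vllm import LLM

llm = LLM(
    model="BAAI/bge-multilingual-gemma2",
    runner="pooling",
    enable_prefix_caching=True,
    block_size=16,
)

# A prefix shorter than one block never fills a complete KV-cache block, so it
# is never cached. Repeat a sentence until it comfortably exceeds block_size.
prefix = "The quick brown fox jumps over the lazy dog. " * 8
assert len(llm.get_tokenizer().encode(prefix)) > 16

prompt1 = prefix + "First input."
prompt2 = prefix + "Second input."
llm.embed([prompt1])  # first request fills the prefix blocks
llm.embed([prompt2])  # second request can reuse the cached prefix blocks
```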
@noooop , the
Thanks for the feedback... @maxdebayser
Wait, don't trust the Gemini code suggestion as-is. It might be wrong. The important thing is to verify whether the DummyPooler is called.
Hmmm... How would you recommend verifying that DummyPooler's `__call__` is actually invoked? Is there a preferred way to do this in the vLLM test setup?
@maxdebayser went through the suggested changes... hope this fixes the remaining issues :)
How about using Qwen/Qwen3-Embedding-0.6B for testing? It's smaller and more people use it. If we only test the last hidden states, I think we can change the chunk size by adjusting max_num_batched_tokens and long_prefill_token_threshold, and verify whether the last hidden states are similar. If we want to compare each chunk's hidden states, I think this problem is equivalent to "Impl chunked prefill + all pooling"; let's implement that in a separate PR. But is it really necessary to compare each hidden state?
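A rough sketch of this plan, assuming embeddings are exposed as outputs[0].outputs.embedding; the extreme max_num_batched_tokens value is only there to force several prefill chunks:

```python
import torch
from vllm import LLM

MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"
PROMPT = "A reasonably long prompt so that a tiny token budget forces chunked prefill."

def embed_once(**engine_kwargs):
    # Build a fresh engine with the given chunking knobs and embed one prompt.
    llm = LLM(model=MODEL_ID, runner="pooling", enforce_eager=True, **engine_kwargs)
    return torch.tensor(llm.embed([PROMPT])[0].outputs.embedding)

# Extreme chunking vs. default settings.
chunked = embed_once(enable_chunked_prefill=True,
                     max_num_batched_tokens=8,
                     long_prefill_token_threshold=4)
baseline = embed_once()

# The last hidden state (and hence the embedding) should barely change.
similarity = torch.nn.functional.cosine_similarity(chunked, baseline, dim=0)
assert similarity > 0.99
```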
Thanks for the suggestion @noooop, plz take a look at the PR I made and suggest the changes you feel are important.
Which plan do you prefer?
tests/test_chunked_prefill_pooler.py
Outdated
```python
# Check that chunks were received
assert len(chunks) > 1
# Compare with non-chunked output
engine_non_chunked = LLMEngine(config, enable_chunked_prefill=False)
```
The LLMEngine class doesn't have an embed method, but vllm.LLM does.
tests/test_chunked_prefill_pooler.py
Outdated
```python
# Compare with non-chunked output
engine_non_chunked = LLMEngine(config, enable_chunked_prefill=False)
output_non_chunked = engine_non_chunked.embed([prompt])
assert output[0] == output_non_chunked[0]
```
Here at the end we need to retrieve the DummyPooler and verify that the sum of the chunk lengths is the same as the prompt length.
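Something along these lines, assuming chunks holds the hidden-state tensors collected by the mock pooler and llm / prompt come from the test itself:

```python
# Each chunk carries [num_tokens_in_chunk, hidden_size] hidden states, so the
# chunk lengths must add up to the tokenized prompt length.
num_prompt_tokens = len(llm.get_tokenizer().encode(prompt))
assert sum(chunk.shape[0] for chunk in chunks) == num_prompt_tokens
```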
@noooop , I agree, we should only verify the last hidden states, which are the ones that go into the pooler. So if we replace the pooler with a mock pooler that collects the chunks, at the end we can verify whether they sum up to the prompt length.
@maxdebayser @noooop what is the course of events now?
Wait, can you summarize your understanding of both approaches? I want to understand how you're thinking about this because in my mind there is no conflict between the suggested changes and the new approach.
Sure! I thought tracking the DummyPooler instance in the test would verify calls, but I missed that the engine creates its own instance, so my check didn’t work🥲 |
Yeah, to summarize:
I think for this last point you could try to override the hf_config for this model, add trust_remote_code=True, and set the auto_map to load a custom class which does the monkey patching. See https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json#L7 for inspiration.
All requested changes have been implemented and the tests now follow the latest recommendations. Please let me know if there's anything else you'd like me to address, or if the PR is ready for final review!
How are you running these tests? When I run them they fail. The error is because you're using LLMEngine instead of LLM.
@maxdebayser also prefers to only verify the last hidden states, so let's make this simpler. Also, I can't think of a scenario where the sum of the chunk lengths is less than the prompt length but still ends up as an output embedding (that is, the prompt wasn't finished yet, but was thought to be finished). Do we really need to spend a lot of effort to get the sum of the chunk lengths? Qwen/Qwen3-Embedding-0.6B uses last pooling. I'm not 100% sure, but I think using Qwen/Qwen3-Embedding-0.6B does not need a DummyPooler; because v1 uses a multi-process architecture, a mock pooler is not simple. Please try adjusting max_num_batched_tokens (you can try extreme values such as 8, 4, or even 1; maybe 1 is too slow, but it is feasible for short prompts) and long_prefill_token_threshold, and verify whether the last hidden states are similar. Please manually rebase before #23398 to ensure this test can catch the issue.
@maxdebayser I got the following error while running with LLM:

```
tests\test_chunked_prefill_pooler.py:50:
vllm\entrypoints\llm.py:272: in __init__
```
@noooop I agree that it's not going to happen if everything works, but just from the pooling result I think we can't know if chunking really happened. In other words, if
@ArkVex , the error message you posted is truncated. Can you post the full error so that we can see what's happening?
Sure... it was too long, that's why I didn't paste it here.
Here you go:

```
(.venv) PS C:\Users\Lenovo\OneDrive\Desktop\vllm> pytest tests/test_chunked_prefill_pooler.py
tests\test_chunked_prefill_pooler.py FF                                  [100%]
=================================== FAILURES ===================================
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x000002D632387590>
tests\test_chunked_prefill_pooler.py:23:
self = <vllm.entrypoints.llm.LLM object at 0x000002D6323F0380>, model = 'sentence-transformers/all-MiniLM-L6-v2', runner = 'pooling', convert = 'auto'
E   TypeError: EngineArgs.__init__() got an unexpected keyword argument 'device'
vllm\entrypoints\llm.py:238: TypeError
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x000002D632408500>
tests\test_chunked_prefill_pooler.py:50:
vllm\entrypoints\llm.py:272: in __init__
self = DeviceConfig(device='', device_type='')
E   RuntimeError: Device string must not be empty
vllm\config\__init__.py:1923: RuntimeError
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
```
Tip: you can paste logs and code between "```" quotes. This is not the code in this branch; can you push your latest changes so we can see them?
@maxdebayser I hope this review is not getting annoying for you... I am new to open source, so I'm learning things a bit the hard way.
```python
        chunks.append(hidden_states)
        return super().__call__(hidden_states, pooling_cursor)
    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
    # Set chunking parameters to force chunked prefill
```
Wait, where is chunked prefill actually being configured here?
tests/test_chunked_prefill_pooler.py
Outdated
```python
    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
    # Set chunking parameters to force chunked prefill
    from vllm.entrypoints.llm import LLM
    llm = LLM(model=model_id, runner="pooling", override_pooler_config=PoolerConfig(pooling_type="LAST"), trust_remote_code=True, device="cpu")
```
Is there any reason to set device=cpu?
No problem @ArkVex , we're here to help ;) You chose a tricky first issue to tackle. I see that you pushed some changes. Can you already run the chunked prefill test without errors?
@maxdebayser thanks... but yeah, I'll get this PR completed and merged.
@maxdebayser can you plz take a look at the new code and tell me where the screwup is happening... I am consistently getting two errors and one warning.
Wait, why did you change the model to
Here, I found a way to inject a mock pooler:

```python
import pytest
import os
import torch
import torch.nn as nn

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import LastPool
from vllm.entrypoints.llm import LLM

prompt = """
Generals gathered in their masses
Just like witches at black masses
Evil minds that plot destruction
Sorcerer of death's construction
In the fields, the bodies burning
As the war machine keeps turning
Death and hatred to mankind
Poisoning their brainwashed minds
Oh, Lord, yeah

Politicians hide themselves away
They only started the war
Why should they go out to fight?
They leave that all to the poor, yeah
Time will tell on their power minds
Making war just for fun
Treating people just like pawns in chess
Wait till their judgment day comes, yeah

Now, in darkness, world stops turning
Ashes where their bodies burning
No more war pigs have the power
Hand of God has struck the hour
Day of Judgment, God is calling
On their knees, the war pigs crawling
Begging mercies for their sins
Satan, laughing, spreads his wings
Oh, Lord, yeah
"""


def test_chunked_prefill_pooler(monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"

        # Use a closure to track chunks
        class WrapperPooler(nn.Module):
            def __init__(self, pooler):
                super().__init__()
                self.pooler = pooler
                self.chunks = []

            def get_pooling_updates(self, task):
                return self.pooler.get_pooling_updates(task)

            def forward(self, hidden_states, pooling_metadata):
                print("forward called")
                self.chunks.append(hidden_states.shape[0])
                return self.pooler(hidden_states, pooling_metadata)

        def inject_pooler(self):
            model = self.get_model()
            wrapper = WrapperPooler(model.pooler)
            model.pooler = wrapper

        def retrieve_chunks(self):
            model = self.get_model()
            return model.pooler.chunks

        # Set chunking parameters to force chunked prefill.
        # Note: chunked prefill is handled by vLLM internally based on the
        # model size and prompt.
        llm = LLM(
            model=model_id,
            runner="pooling",
            long_prefill_token_threshold=10,
            tensor_parallel_size=1,
            enforce_eager=True,  # Helps with Windows compatibility
        )
        llm.llm_engine.collective_rpc(inject_pooler)

        # prompt = "This is a test prompt for chunked prefill."
        output = llm.embed([prompt])
        chunks = llm.llm_engine.collective_rpc(retrieve_chunks)[0]
        # Check that WrapperPooler was called and chunks were received
        assert len(chunks) > 1  # <------ It has to be larger than 1, not 0
```

After this, `chunks` will contain the per-chunk token counts seen by the pooler. To compare with the second LLM instance, you need to disable chunked prefill when constructing it.
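A possible sketch of that final comparison, continuing from the test above; the enable_chunked_prefill=False flag and the .outputs.embedding field are assumptions based on the rest of this thread:

```python
import torch
from vllm.entrypoints.llm import LLM

# A second engine with chunked prefill disabled serves as the reference.
llm_ref = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",
    enable_chunked_prefill=False,
    enforce_eager=True,
)
reference = torch.tensor(llm_ref.embed([prompt])[0].outputs.embedding)
chunked = torch.tensor(output[0].outputs.embedding)  # `output` from llm.embed() above

# The chunked and non-chunked embeddings should agree closely.
assert torch.nn.functional.cosine_similarity(chunked, reference, dim=0) > 0.99
```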
Add CI tests for chunked prefill and prefix caching in pooling models (issue #23436).
Tests use a custom pooler to track hidden state chunks and compare outputs with and without chunking.
All new tests pass and outputs are consistent.
Test command: pytest