From c5fe531b2b84ef9ba91e3dacf9dca5b544752427 Mon Sep 17 00:00:00 2001
From: Ayush Singh
Date: Tue, 2 Sep 2025 21:58:44 +0530
Subject: [PATCH 1/5] test_chunked_prefill_pooler

---
 tests/test_chunked_prefill_pooler.py | 48 ++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 tests/test_chunked_prefill_pooler.py

diff --git a/tests/test_chunked_prefill_pooler.py b/tests/test_chunked_prefill_pooler.py
new file mode 100644
index 000000000000..58c2a9b66e2a
--- /dev/null
+++ b/tests/test_chunked_prefill_pooler.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from vllm.config import ModelConfig, PoolerConfig
+from vllm.model_executor.layers.pooler import PoolingType, LastPool
+from vllm.engine.llm_engine import LLMEngine
+
+class DummyPooler(LastPool):
+    def __init__(self):
+        super().__init__()
+        self.chunks = []
+    def __call__(self, hidden_states, pooling_cursor):
+        self.chunks.append(hidden_states)
+        return super().__call__(hidden_states, pooling_cursor)
+
+def test_chunked_prefill_pooler(monkeypatch):
+    """Test chunked prefill for pooling models with LastPool."""
+    model_id = "BAAI/bge-multilingual-gemma2"
+    config = ModelConfig(model_id)
+    pooler = DummyPooler()
+    config.pooler_config = PoolerConfig(pooling_type="LAST")
+    # Patch LLMEngine to use DummyPooler
+    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
+    engine = LLMEngine(config)
+    prompt = "This is a test prompt for chunked prefill."
+    output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
+    # Check that chunks were received
+    assert len(pooler.chunks) > 1
+    # Compare with non-chunked output
+    output_non_chunked = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=False)
+    assert output[0] == output_non_chunked[0]
+
+def test_chunked_prefill_prefix_caching(monkeypatch):
+    """Test chunked prefill with prefix caching for pooling models."""
+    model_id = "BAAI/bge-multilingual-gemma2"
+    config = ModelConfig(model_id)
+    pooler = DummyPooler()
+    config.pooler_config = PoolerConfig(pooling_type="LAST")
+    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
+    engine = LLMEngine(config)
+    prefix = "Common prefix. "
+    prompt1 = prefix + "First input."
+    prompt2 = prefix + "Second input."
+    engine.generate([prompt1], max_tokens=8, enable_chunked_prefill=True)
+    output2 = engine.generate([prompt2], max_tokens=8, enable_chunked_prefill=True)
+    # The pooler should see hidden states of length (total - prefix length)
+    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in pooler.chunks)

From f4bc32442c68de665756d1855cb4ed9c49155fb4 Mon Sep 17 00:00:00 2001
From: Ayush Singh
Date: Tue, 2 Sep 2025 23:10:52 +0530
Subject: [PATCH 2/5] test_chunked_prefill_pooler

---
 tests/test_chunked_prefill_pooler.py | 46 +++++++++++++++-------------
 1 file changed, 25 insertions(+), 21 deletions(-)

diff --git a/tests/test_chunked_prefill_pooler.py b/tests/test_chunked_prefill_pooler.py
index 58c2a9b66e2a..b6604d891916 100644
--- a/tests/test_chunked_prefill_pooler.py
+++ b/tests/test_chunked_prefill_pooler.py
@@ -1,48 +1,52 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
 
 import pytest
 from vllm.config import ModelConfig, PoolerConfig
-from vllm.model_executor.layers.pooler import PoolingType, LastPool
-from vllm.engine.llm_engine import LLMEngine
-
-class DummyPooler(LastPool):
-    def __init__(self):
-        super().__init__()
-        self.chunks = []
-    def __call__(self, hidden_states, pooling_cursor):
-        self.chunks.append(hidden_states)
-        return super().__call__(hidden_states, pooling_cursor)
+from vllm.model_executor.layers.pooler import LastPool
+from vllm.v1.engine.llm_engine import LLMEngine
 
 def test_chunked_prefill_pooler(monkeypatch):
     """Test chunked prefill for pooling models with LastPool."""
     model_id = "BAAI/bge-multilingual-gemma2"
     config = ModelConfig(model_id)
-    pooler = DummyPooler()
     config.pooler_config = PoolerConfig(pooling_type="LAST")
-    # Patch LLMEngine to use DummyPooler
+    # Use a closure to track chunks
+    chunks = []
+    class DummyPooler(LastPool):
+        def __call__(self, hidden_states, pooling_cursor):
+            chunks.append(hidden_states)
+            return super().__call__(hidden_states, pooling_cursor)
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
-    engine = LLMEngine(config)
+    # Set chunking parameters to force chunked prefill
+    engine = LLMEngine(config, enable_chunked_prefill=True, long_prefill_token_threshold=1)
     prompt = "This is a test prompt for chunked prefill."
-    output = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=True)
+    output = engine.embed([prompt])
     # Check that chunks were received
-    assert len(pooler.chunks) > 1
+    assert len(chunks) > 1
     # Compare with non-chunked output
-    output_non_chunked = engine.generate([prompt], max_tokens=8, enable_chunked_prefill=False)
+    engine_non_chunked = LLMEngine(config, enable_chunked_prefill=False)
+    output_non_chunked = engine_non_chunked.embed([prompt])
     assert output[0] == output_non_chunked[0]
 
 def test_chunked_prefill_prefix_caching(monkeypatch):
     """Test chunked prefill with prefix caching for pooling models."""
     model_id = "BAAI/bge-multilingual-gemma2"
     config = ModelConfig(model_id)
-    pooler = DummyPooler()
     config.pooler_config = PoolerConfig(pooling_type="LAST")
+    chunks = []
+    class DummyPooler(LastPool):
+        def __call__(self, hidden_states, pooling_cursor):
+            chunks.append(hidden_states)
+            return super().__call__(hidden_states, pooling_cursor)
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
-    engine = LLMEngine(config)
+    engine = LLMEngine(config, enable_chunked_prefill=True, long_prefill_token_threshold=1)
     prefix = "Common prefix. "
     prompt1 = prefix + "First input."
     prompt2 = prefix + "Second input."
-    engine.generate([prompt1], max_tokens=8, enable_chunked_prefill=True)
-    output2 = engine.generate([prompt2], max_tokens=8, enable_chunked_prefill=True)
+    engine.embed([prompt1])
+    chunks.clear()
+    engine.embed([prompt2])
     # The pooler should see hidden states of length (total - prefix length)
-    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in pooler.chunks)
+    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in chunks)

From b1749c9e005bf52ca715c10d2c7e4277913091b9 Mon Sep 17 00:00:00 2001
From: Ayush Singh
Date: Thu, 4 Sep 2025 03:05:59 +0530
Subject: [PATCH 3/5] suggested changes

---
 tests/test_chunked_prefill_pooler.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tests/test_chunked_prefill_pooler.py b/tests/test_chunked_prefill_pooler.py
index b6604d891916..dd35153c19bf 100644
--- a/tests/test_chunked_prefill_pooler.py
+++ b/tests/test_chunked_prefill_pooler.py
@@ -23,12 +23,17 @@ def __call__(self, hidden_states, pooling_cursor):
     engine = LLMEngine(config, enable_chunked_prefill=True, long_prefill_token_threshold=1)
     prompt = "This is a test prompt for chunked prefill."
     output = engine.embed([prompt])
-    # Check that chunks were received
-    assert len(chunks) > 1
+    # Check that DummyPooler was called and chunks were received
+    assert len(chunks) > 0
+    # Verify the sum of the lengths of the chunks matches the prompt length
+    total_chunk_len = sum(len(chunk) for chunk in chunks)
+    assert total_chunk_len == len(prompt)
     # Compare with non-chunked output
     engine_non_chunked = LLMEngine(config, enable_chunked_prefill=False)
     output_non_chunked = engine_non_chunked.embed([prompt])
     assert output[0] == output_non_chunked[0]
+    # Note: For faster tests, use a smaller model like 'Qwen/Qwen3-Embedding-0.6'.
+    # To override the pooler, you can set trust_remote_code=True and use auto_map in hf_config.
 
 def test_chunked_prefill_prefix_caching(monkeypatch):
     """Test chunked prefill with prefix caching for pooling models."""
@@ -48,5 +53,7 @@ def __call__(self, hidden_states, pooling_cursor):
     engine.embed([prompt1])
     chunks.clear()
     engine.embed([prompt2])
-    # The pooler should see hidden states of length (total - prefix length)
-    assert all(len(chunk) <= len(prompt2) - len(prefix) for chunk in chunks)
+    # Only the last hidden states should be checked (those going into the pooler)
+    # Verify the sum of the lengths of the chunks matches the prompt length minus prefix
+    total_chunk_len = sum(len(chunk) for chunk in chunks)
+    assert total_chunk_len == len(prompt2) - len(prefix)

From d7084c0849266ec0f0e986d69e9ecd016d018c47 Mon Sep 17 00:00:00 2001
From: Ayush Singh
Date: Fri, 5 Sep 2025 14:37:03 +0530
Subject: [PATCH 4/5] new changes

---
 tests/test_chunked_prefill_pooler.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/tests/test_chunked_prefill_pooler.py b/tests/test_chunked_prefill_pooler.py
index dd35153c19bf..1f958cf29fe8 100644
--- a/tests/test_chunked_prefill_pooler.py
+++ b/tests/test_chunked_prefill_pooler.py
@@ -5,11 +5,10 @@
 import pytest
 from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.layers.pooler import LastPool
-from vllm.v1.engine.llm_engine import LLMEngine
 
 def test_chunked_prefill_pooler(monkeypatch):
     """Test chunked prefill for pooling models with LastPool."""
-    model_id = "BAAI/bge-multilingual-gemma2"
+    model_id = "sentence-transformers/all-MiniLM-L6-v2"
     config = ModelConfig(model_id)
     config.pooler_config = PoolerConfig(pooling_type="LAST")
     # Use a closure to track chunks
@@ -20,24 +19,25 @@ def __call__(self, hidden_states, pooling_cursor):
         return super().__call__(hidden_states, pooling_cursor)
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
     # Set chunking parameters to force chunked prefill
-    engine = LLMEngine(config, enable_chunked_prefill=True, long_prefill_token_threshold=1)
+    from vllm.entrypoints.llm import LLM
+    llm = LLM(model=model_id, runner="pooling", override_pooler_config=PoolerConfig(pooling_type="LAST"), trust_remote_code=True, device="cpu")
     prompt = "This is a test prompt for chunked prefill."
-    output = engine.embed([prompt])
+    output = llm.embed([prompt])
     # Check that DummyPooler was called and chunks were received
     assert len(chunks) > 0
     # Verify the sum of the lengths of the chunks matches the prompt length
     total_chunk_len = sum(len(chunk) for chunk in chunks)
     assert total_chunk_len == len(prompt)
     # Compare with non-chunked output
-    engine_non_chunked = LLMEngine(config, enable_chunked_prefill=False)
-    output_non_chunked = engine_non_chunked.embed([prompt])
+    llm_non_chunked = LLM(model=model_id, runner="pooling", override_pooler_config=PoolerConfig(pooling_type="LAST"), trust_remote_code=True)
+    output_non_chunked = llm_non_chunked.embed([prompt])
     assert output[0] == output_non_chunked[0]
     # Note: For faster tests, use a smaller model like 'Qwen/Qwen3-Embedding-0.6'.
     # To override the pooler, you can set trust_remote_code=True and use auto_map in hf_config.
 
 def test_chunked_prefill_prefix_caching(monkeypatch):
     """Test chunked prefill with prefix caching for pooling models."""
-    model_id = "BAAI/bge-multilingual-gemma2"
+    model_id = "sentence-transformers/all-MiniLM-L6-v2"
     config = ModelConfig(model_id)
     config.pooler_config = PoolerConfig(pooling_type="LAST")
     chunks = []
@@ -46,13 +46,19 @@ def __call__(self, hidden_states, pooling_cursor):
             chunks.append(hidden_states)
             return super().__call__(hidden_states, pooling_cursor)
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
-    engine = LLMEngine(config, enable_chunked_prefill=True, long_prefill_token_threshold=1)
+    from vllm.entrypoints.llm import LLM
+    llm = LLM(
+    model=model_id,
+    runner="pooling",
+    override_pooler_config=PoolerConfig(pooling_type="LAST"),
+    trust_remote_code=True,
+)
     prefix = "Common prefix. "
     prompt1 = prefix + "First input."
     prompt2 = prefix + "Second input."
-    engine.embed([prompt1])
+    llm.embed([prompt1])
     chunks.clear()
-    engine.embed([prompt2])
+    llm.embed([prompt2])
     # Only the last hidden states should be checked (those going into the pooler)
     # Verify the sum of the lengths of the chunks matches the prompt length minus prefix
     total_chunk_len = sum(len(chunk) for chunk in chunks)

From 0c4e290c252cc123e8e4ee2a54c300f81a04fc2c Mon Sep 17 00:00:00 2001
From: Ayush Singh
Date: Thu, 11 Sep 2025 14:21:23 +0530
Subject: [PATCH 5/5] updated pr

---
 tests/test_chunked_prefill_pooler.py | 68 +++++++++++++++++++++++-----
 1 file changed, 57 insertions(+), 11 deletions(-)

diff --git a/tests/test_chunked_prefill_pooler.py b/tests/test_chunked_prefill_pooler.py
index 1f958cf29fe8..ca02373bac79 100644
--- a/tests/test_chunked_prefill_pooler.py
+++ b/tests/test_chunked_prefill_pooler.py
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-
 import pytest
+import os
+import torch
 from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.layers.pooler import LastPool
 
@@ -11,27 +11,58 @@ def test_chunked_prefill_pooler(monkeypatch):
     model_id = "sentence-transformers/all-MiniLM-L6-v2"
     config = ModelConfig(model_id)
     config.pooler_config = PoolerConfig(pooling_type="LAST")
+
     # Use a closure to track chunks
     chunks = []
+
     class DummyPooler(LastPool):
         def __call__(self, hidden_states, pooling_cursor):
             chunks.append(hidden_states)
             return super().__call__(hidden_states, pooling_cursor)
+
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
+
+    # Set environment variables for Windows compatibility
+    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Force CPU usage on Windows
+
     # Set chunking parameters to force chunked prefill
     from vllm.entrypoints.llm import LLM
-    llm = LLM(model=model_id, runner="pooling", override_pooler_config=PoolerConfig(pooling_type="LAST"), trust_remote_code=True, device="cpu")
+
+    # Note: Chunked prefill is automatically handled by vLLM internally based on the model size and prompt
+    llm = LLM(
+        model=model_id,
+        runner="pooling",
+        override_pooler_config=PoolerConfig(pooling_type="LAST"),
+        trust_remote_code=True,
+        tensor_parallel_size=1,
+        enforce_eager=True,  # Helps with Windows compatibility
+    )
+
     prompt = "This is a test prompt for chunked prefill."
     output = llm.embed([prompt])
+
     # Check that DummyPooler was called and chunks were received
     assert len(chunks) > 0
+
     # Verify the sum of the lengths of the chunks matches the prompt length
     total_chunk_len = sum(len(chunk) for chunk in chunks)
     assert total_chunk_len == len(prompt)
+
     # Compare with non-chunked output
-    llm_non_chunked = LLM(model=model_id, runner="pooling", override_pooler_config=PoolerConfig(pooling_type="LAST"), trust_remote_code=True)
+    llm_non_chunked = LLM(
+        model=model_id,
+        runner="pooling",
+        override_pooler_config=PoolerConfig(pooling_type="LAST"),
+        trust_remote_code=True,
+        tensor_parallel_size=1,
+        enforce_eager=True,
+    )
     output_non_chunked = llm_non_chunked.embed([prompt])
-    assert output[0] == output_non_chunked[0]
+
+    # Compare embeddings with tolerance for floating point differences
+    assert torch.allclose(torch.tensor(output[0]), torch.tensor(output_non_chunked[0]), atol=1e-6)
+
     # Note: For faster tests, use a smaller model like 'Qwen/Qwen3-Embedding-0.6'.
     # To override the pooler, you can set trust_remote_code=True and use auto_map in hf_config.
 
@@ -40,26 +71,41 @@ def test_chunked_prefill_prefix_caching(monkeypatch):
     model_id = "sentence-transformers/all-MiniLM-L6-v2"
     config = ModelConfig(model_id)
     config.pooler_config = PoolerConfig(pooling_type="LAST")
+
     chunks = []
+
     class DummyPooler(LastPool):
         def __call__(self, hidden_states, pooling_cursor):
             chunks.append(hidden_states)
             return super().__call__(hidden_states, pooling_cursor)
+
     monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
+
+    # Set environment variables for Windows compatibility
+    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Force CPU usage on Windows
+
     from vllm.entrypoints.llm import LLM
+
+    # Note: Chunked prefill is automatically handled by vLLM internally based on the model size and prompt
     llm = LLM(
-    model=model_id,
-    runner="pooling",
-    override_pooler_config=PoolerConfig(pooling_type="LAST"),
-    trust_remote_code=True,
-)
+        model=model_id,
+        runner="pooling",
+        override_pooler_config=PoolerConfig(pooling_type="LAST"),
+        trust_remote_code=True,
+        tensor_parallel_size=1,
+        enforce_eager=True,  # Helps with Windows compatibility
+    )
+
     prefix = "Common prefix. "
     prompt1 = prefix + "First input."
     prompt2 = prefix + "Second input."
+
     llm.embed([prompt1])
     chunks.clear()
     llm.embed([prompt2])
+
    # Only the last hidden states should be checked (those going into the pooler)
     # Verify the sum of the lengths of the chunks matches the prompt length minus prefix
     total_chunk_len = sum(len(chunk) for chunk in chunks)
-    assert total_chunk_len == len(prompt2) - len(prefix)
+    assert total_chunk_len == len(prompt2) - len(prefix)
\ No newline at end of file