111 changes: 111 additions & 0 deletions tests/test_chunked_prefill_pooler.py
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import torch

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import LastPool

def test_chunked_prefill_pooler(monkeypatch):
"""Test chunked prefill for pooling models with LastPool."""
model_id = "sentence-transformers/all-MiniLM-L6-v2"
config = ModelConfig(model_id)
config.pooler_config = PoolerConfig(pooling_type="LAST")

# Use a closure to track chunks
chunks = []

class DummyPooler(LastPool):
def __call__(self, hidden_states, pooling_cursor):
chunks.append(hidden_states)
return super().__call__(hidden_states, pooling_cursor)

monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool", DummyPooler)
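    # The patched pooler records every hidden-state chunk it receives, so the
    # test can observe how the prefill was split. Caveat (assumption, not from
    # the original change): this only works if the pooler is constructed in
    # this process; the patch would not propagate to spawned worker processes.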

    # Set environment variables for Windows compatibility
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Force CPU usage on Windows

    # Set chunking parameters to force chunked prefill
Contributor: Wait, where is chunked prefill being configured?
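    # Sketch in response to the review question (illustrative, not part of
    # this change): chunked prefill could be forced explicitly through the
    # engine arguments, e.g.
    #
    #     llm = LLM(
    #         model=model_id,
    #         runner="pooling",
    #         enable_chunked_prefill=True,
    #         max_num_batched_tokens=16,  # small budget so prompts get split
    #     )
    #
    # The parameter values above are assumptions chosen for illustration.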

    from vllm.entrypoints.llm import LLM

    # Note: chunked prefill is handled automatically by vLLM based on the
    # model size and the prompt.
    llm = LLM(
        model=model_id,
runner="pooling",
override_pooler_config=PoolerConfig(pooling_type="LAST"),
trust_remote_code=True,
tensor_parallel_size=1,
enforce_eager=True, # Helps with Windows compatibility
)

prompt = "This is a test prompt for chunked prefill."
output = llm.embed([prompt])

# Check that DummyPooler was called and chunks were received
assert len(chunks) > 0

    # The chunks hold per-token hidden states, so their combined length should
    # equal the number of prompt tokens (not the character length).
    num_prompt_tokens = len(llm.get_tokenizer()(prompt)["input_ids"])
    total_chunk_len = sum(len(chunk) for chunk in chunks)
    assert total_chunk_len == num_prompt_tokens
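    # For intuition (hypothetical numbers, not asserted here): with a prefill
    # budget of 16 tokens and a 40-token prompt, the pooler would see chunks
    # of 16, 16 and 8 hidden states, summing to 40.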

    # Compare with non-chunked output
    llm_non_chunked = LLM(
        model=model_id,
        runner="pooling",
        override_pooler_config=PoolerConfig(pooling_type="LAST"),
        trust_remote_code=True,
        tensor_parallel_size=1,
        enforce_eager=True,
    )
    output_non_chunked = llm_non_chunked.embed([prompt])

    # Compare embeddings with a tolerance for floating point differences.
    assert torch.allclose(
        torch.tensor(output[0].outputs.embedding),
        torch.tensor(output_non_chunked[0].outputs.embedding),
        atol=1e-6,
    )

# Note: For faster tests, use a smaller model such as
# "Qwen/Qwen3-Embedding-0.6B". To override the pooler, you can set
# trust_remote_code=True and use auto_map in hf_config.

def test_chunked_prefill_prefix_caching(monkeypatch):
"""Test chunked prefill with prefix caching for pooling models."""

Check failure on line 70 in tests/test_chunked_prefill_pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

tests/test_chunked_prefill_pooler.py:70:81: E501 Line too long (82 > 80)
model_id = "sentence-transformers/all-MiniLM-L6-v2"

Check failure on line 71 in tests/test_chunked_prefill_pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

tests/test_chunked_prefill_pooler.py:71:81: E501 Line too long (95 > 80)
    config = ModelConfig(model_id)
    config.pooler_config = PoolerConfig(pooling_type="LAST")

    chunks = []

    class DummyPooler(LastPool):
        def __call__(self, hidden_states, pooling_cursor):
            chunks.append(hidden_states)
            return super().__call__(hidden_states, pooling_cursor)

    monkeypatch.setattr("vllm.model_executor.layers.pooler.LastPool",
                        DummyPooler)

    # Set environment variables for Windows compatibility
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Force CPU usage on Windows

    from vllm.entrypoints.llm import LLM

    # Note: chunked prefill is handled automatically by vLLM based on the
    # model size and the prompt.
    llm = LLM(
        model=model_id,
        runner="pooling",
        override_pooler_config=PoolerConfig(pooling_type="LAST"),
        trust_remote_code=True,
        tensor_parallel_size=1,
        enforce_eager=True,  # Helps with Windows compatibility
    )

prefix = "Common prefix. "
prompt1 = prefix + "First input."
Contributor: The prefix has to be bigger than a paged attention block, otherwise it's not cached.

    prompt2 = prefix + "Second input."
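    # Per the review note above, the cached prefix must span at least one full
    # paged-attention block to be reused. Sketch (assumptions, not from this
    # change): vLLM's default block_size is 16 tokens, so either make the
    # prefix >= 16 tokens or pass a smaller block_size, and enable caching
    # explicitly with enable_prefix_caching=True in LLM(...).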

    llm.embed([prompt1])
    chunks.clear()
    llm.embed([prompt2])

    # The chunks hold per-token hidden states; with the prefix cached, they
    # should cover only prompt2's tokens beyond the prefix (block-aligned
    # caching caveat noted in the review above).
    tokenizer = llm.get_tokenizer()
    num_prefix = len(tokenizer(prefix, add_special_tokens=False)["input_ids"])
    num_prompt2 = len(tokenizer(prompt2)["input_ids"])
    total_chunk_len = sum(len(chunk) for chunk in chunks)
    assert total_chunk_len == num_prompt2 - num_prefix