81 changes: 53 additions & 28 deletions docs/models/pooling_models.md
@@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
enabling the corresponding APIs:

| Task | APIs |
|------------|--------------------|
| `encode` | `encode` |
| `embed` | `embed`, `score`\* |
| `classify` | `classify` |
| `score` | `score` |
| Task | APIs |
|------------|--------------------------------------|
| `encode` | `LLM.reward(...)` |
| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* |
| `classify` | `LLM.classify(...)` |
| `score` | `LLM.score(...)` |

\* The `score` API falls back to `embed` task if the model does not support `score` task.
\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
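
As an illustration of this fallback, the sketch below calls `LLM.score(...)` with an embedding-only model, so the score is derived from the `embed` task; the model name and the output field mirror the other examples on this page rather than anything introduced by this change.

```python
from vllm import LLM

# Embedding-only model: `LLM.score(...)` falls back to the `embed` task and
# derives the similarity score from the pooled embeddings.
llm = LLM(model="intfloat/e5-small", runner="pooling")

(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")

score = output.outputs.score
print(f"Score: {score}")
```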

### Pooler Configuration

@@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` option
If the model has been converted via `--convert` (see above),
the pooler assigned to each task has the following attributes by default:

| Task | Pooling Type | Normalization | Softmax |
|------------|----------------|---------------|---------|
| `encode` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |
| Task | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |

When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
the Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
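
As a sketch of what such an override can look like in offline mode (the `override_pooler_config` argument and the `PoolerConfig` field names here are assumptions about vLLM's configuration API, not something introduced by this change):

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Assumed illustration: switch the converted model's pooler to mean pooling
# without normalization, overriding the defaults from the table above.
llm = LLM(
    model="intfloat/e5-small",
    runner="pooling",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)
# Online serving would use the equivalent JSON flag, e.g.
# --override-pooler-config '{"pooling_type": "MEAN", "normalize": false}'
```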
@@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults.
The [LLM][vllm.LLM] class provides various methods for offline inference.
See [configuration][configuration] for a list of options when initializing the model.

### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly, which is useful for reward models.

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling")
(output,) = llm.encode("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

### `LLM.embed`

The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
@@ -106,7 +91,7 @@ It is primarily designed for embedding models.
```python
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling")
llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
@@ -154,6 +139,46 @@ print(f"Score: {score}")

A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>

### `LLM.reward`

The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
It returns the extracted hidden states directly.

```python
from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
(output,) = llm.reward("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>

### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly.

!!! note
Please use one of the more specific methods or set the task directly when using `LLM.encode`:

- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
- For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
- For similarity scores, use `LLM.score(...)`.

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="embed")

data = output.outputs.data
print(f"Data: {data!r}")
```

## Online Serving

Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
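
As one example of these endpoints, a minimal request to the OpenAI-compatible embeddings route is sketched below; the served model and the `vllm serve intfloat/e5-small --runner pooling` invocation are illustrative assumptions rather than part of this change.

```python
# Assumes a server started with something like:
#   vllm serve intfloat/e5-small --runner pooling
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "intfloat/e5-small", "input": ["Hello, my name is"]},
)
response.raise_for_status()

# The response follows the OpenAI embeddings schema.
embedding = response.json()["data"][0]["embedding"]
print(f"Embedding dimension: {len(embedding)}")
```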
3 changes: 1 addition & 2 deletions examples/offline_inference/basic/embed.py
@@ -12,10 +12,9 @@ def parse_args():
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct",
model="intfloat/e5-small",
runner="pooling",
enforce_eager=True,
max_model_len=1024,
)
return parser.parse_args()

53 changes: 53 additions & 0 deletions examples/offline_inference/basic/reward.py
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="internlm/internlm2-1_8b-reward",
runner="pooling",
enforce_eager=True,
max_model_len=1024,
trust_remote_code=True,
)
return parser.parse_args()


def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create an LLM.
# You should pass runner="pooling" for reward models
llm = LLM(**vars(args))

# Generate rewards. The output is a list of PoolingRequestOutput.
outputs = llm.reward(prompts)

# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
rewards = output.outputs.data
rewards_trimmed = (
(str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
)
print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
print("-" * 60)


if __name__ == "__main__":
args = parse_args()
main(args)
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -1053,6 +1053,10 @@ def encode(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts)
return [req_output.outputs.data for req_output in req_outputs]

def reward(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.reward(prompts)
return [req_output.outputs.data for req_output in req_outputs]

def score(
self,
text_1: Union[str, list[str]],
2 changes: 1 addition & 1 deletion tests/models/language/pooling/test_reward.py
@@ -95,7 +95,7 @@ def test_prm_models(
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(math_step_prompts)
vllm_outputs = vllm_model.reward(math_step_prompts)

with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_model = step_reward_patch_hf_model(hf_model)
6 changes: 3 additions & 3 deletions tests/models/language/pooling/test_truncation_control.py
@@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,

with vllm_runner(model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.llm.encode(
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

prompt_tokens = vllm_output[0].prompt_token_ids
Expand All @@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,

with vllm_runner(model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.llm.encode(
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

prompt_tokens = vllm_output[0].prompt_token_ids
Expand All @@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:

llm_output = vllm_model.llm.encode(
llm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

assert llm_output == f"""truncate_prompt_tokens value
60 changes: 59 additions & 1 deletion vllm/entrypoints/llm.py
@@ -1037,7 +1037,7 @@ def encode(
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
pooling_task: Optional[PoolingTask] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
@@ -1069,6 +1069,25 @@ considered legacy and may be deprecated in the future. You should
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
if pooling_task is None:
if "embed" in self.supported_tasks:
pooling_task = "embed"
else:
pooling_task = "encode"

logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n"
"Please use one of the more specific methods or set the "
"task directly when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
"or `pooling_task=\"embed\"`.\n"
" - For classification logits, use `LLM.classify(...)` "
"or `pooling_task=\"classify\"`.\n"
" - For rewards, use `LLM.reward(...)` "
"or `pooling_task=\"reward\"`\n"
" - For similarity scores, use `LLM.score(...)`.",
pooling_task)

model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
@@ -1207,6 +1226,45 @@ def classify(

return [ClassificationRequestOutput.from_base(item) for item in items]

def reward(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[PoolingRequestOutput]:
"""
Generate rewards for each prompt.

Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""

return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="encode",
)

def _embedding_score(
self,
tokenizer: AnyTokenizer,