From 7a9a647adf620b180e24f7e3cdc5782db4cf6a36 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 16 Oct 2025 09:26:46 +0800 Subject: [PATCH 01/11] add plugin pooling task Signed-off-by: wang.yuqi --- .../prithvi_geospatial_mae.py | 2 +- .../prithvi_geospatial_mae_io_processor.py | 3 +-- .../online_serving/prithvi_geospatial_mae.py | 1 - .../multimodal/pooling/test_prithvi_mae.py | 2 +- .../test_io_processor_plugins.py | 4 +--- vllm/entrypoints/openai/protocol.py | 7 +----- vllm/entrypoints/openai/serving_pooling.py | 24 +++---------------- vllm/model_executor/layers/pooler.py | 12 ++++++++++ vllm/model_executor/models/terratorch.py | 6 ++--- vllm/pooling_params.py | 5 ++++ vllm/tasks.py | 4 +++- 11 files changed, 30 insertions(+), 40 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 2c73ed6aa608..56c1196c2ed7 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -63,7 +63,7 @@ def run(self, input_data, location_coords): } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - outputs = self.model.encode(prompt, use_tqdm=False) + outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False) return outputs[0].outputs.data diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 6c47b5715438..8abe524d4954 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -40,10 +40,9 @@ def main(): model_impl="terratorch", ) - pooling_params = PoolingParams(task="token_classify", activation=False) pooler_output = llm.encode( img_prompt, - pooling_params=pooling_params, + pooling_task="plugin" ) output = pooler_output[0].outputs diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index 611a7cbc89fa..3fa2c31196cd 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -34,7 +34,6 @@ def main(): }, "priority": 0, "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", - "softmax": False, } ret = requests.post(server_endpoint, json=request_payload_url) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 62154b083487..a54b0ba8461d 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -39,7 +39,7 @@ def _run_test( max_num_seqs=32, default_torch_num_threads=1, ) as vllm_model: - vllm_model.llm.encode(prompt, pooling_task="token_classify") + vllm_model.llm.encode(prompt, pooling_task="plugin") MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 936f27fb69bc..76c65d3babb7 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -93,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): out_data_format="b64_json", ) - pooling_params = PoolingParams(activation=False) - with vllm_runner( model_name, runner="pooling", @@ -108,7 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): io_processor_plugin="prithvi_to_tiff", ) as llm_runner: 
pooler_output = llm_runner.get_llm().encode( - img_prompt, pooling_params=pooling_params, pooling_task="token_classify" + img_prompt, pooling_task="plugin" ) output = pooler_output[0].outputs diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5b8a118280da..4ba34032c427 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1678,11 +1678,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): if the served model does not use priority scheduling. """ data: T - """ - When using plugins IOProcessor plugins, the actual input is processed - by the plugin itself. Hence, we use a generic type for the request data - """ - activation: bool = False embed_dtype: str = Field( default="float32", @@ -1693,7 +1688,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ) def to_pooling_params(self): - return PoolingParams(task="token_classify", activation=self.activation) + return PoolingParams() class IOProcessorResponse(OpenAIBaseModel, Generic[T]): diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 7a27348da35b..56b5a712b943 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -2,16 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import base64 import time from collections.abc import AsyncGenerator from typing import Final, Literal, cast import jinja2 -import numpy as np -import torch from fastapi import Request -from typing_extensions import assert_never from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption @@ -34,29 +30,13 @@ from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger -from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.outputs import PoolingRequestOutput from vllm.tasks import SupportedTask from vllm.utils.async_utils import merge_async_iterators logger = init_logger(__name__) -def _get_data( - output: PoolingOutput, - encoding_format: Literal["float", "base64"], -) -> list[float] | str: - if encoding_format == "float": - return output.data.tolist() - elif encoding_format == "base64": - # Force to use float32 for base64 encoding - # to match the OpenAI python client behavior - pt_float32 = output.data.to(dtype=torch.float32) - pooling_bytes = np.array(pt_float32, dtype="float32").tobytes() - return base64.b64encode(pooling_bytes).decode("utf-8") - - assert_never(encoding_format) - - class OpenAIServingPooling(OpenAIServing): def __init__( self, @@ -185,6 +165,8 @@ async def create_pooling( pooling_task = "token_embed" elif "token_classify" in self.supported_tasks: pooling_task = "token_classify" + elif "plugin" in self.supported_tasks: + pooling_task = "plugin" else: return self.create_error_response( f"pooling_task must be one of {self.supported_tasks}." 
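The hunk above extends the default-task fallback in `create_pooling` so that models exposing only the new task resolve to "plugin" rather than hitting the error branch. A minimal sketch of that selection order, assuming `supported_tasks` is any container supporting membership tests; `select_pooling_task` is an illustrative name, not vLLM API:

```python
# Hedged sketch of the fallback order in create_pooling() above;
# select_pooling_task is an illustrative helper, not part of vLLM.
def select_pooling_task(supported_tasks: frozenset[str]) -> str:
    for task in ("token_embed", "token_classify", "plugin"):
        if task in supported_tasks:
            return task
    raise ValueError(f"pooling_task must be one of {supported_tasks}.")


# A pooler that only reports {"plugin", "score"} (as the DummyPooler added
# in the next file does) now resolves to "plugin" instead of raising.
assert select_pooling_task(frozenset({"plugin", "score"})) == "plugin"
```

Placing "plugin" last in the chain keeps the token-level tasks preferred whenever a model supports several.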
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 84e176f0ea89..920d50dfbd3c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -414,6 +414,18 @@ def forward( raise NotImplementedError +class DummyPooler(Pooler): + def get_supported_tasks(self) -> Set[PoolingTask]: + raise {"plugin", "score"} + + def forward( + self, + hidden_states: list[torch.Tensor] | torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return hidden_states + + class PoolerHead(nn.Module): def __init__(self, activation: PoolerActivation) -> None: super().__init__() diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 0252705c62b1..e799e41e2c38 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -34,7 +34,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -249,9 +249,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"token_classify": Pooler.for_token_classify(pooler_config)} - ) + self.pooler = DispatchPooler({"plugin": DummyPooler()}) def get_input_embeddings( self, diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index c6dff6e01c1d..090d92414465 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -84,6 +84,11 @@ def verify( msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" raise ValueError(msg) + # plugin task uses io_processor.parse_request to verify inputs, + # skipping PoolingParams verify + if self.task == "plugin": + return + # NOTE: Task validation needs to done against the model instance, # which is not available in model config. 
So, it's not included # in this method diff --git a/vllm/tasks.py b/vllm/tasks.py index 6551444d1710..b02cde74c12a 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -5,7 +5,9 @@ GenerationTask = Literal["generate", "transcription"] GENERATION_TASKS = get_args(GenerationTask) -PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"] +PoolingTask = Literal[ + "embed", "classify", "score", "token_embed", "token_classify", "plugin" +] POOLING_TASKS = get_args(PoolingTask) SupportedTask = Literal[GenerationTask, PoolingTask] From cefd6ebd12535bd46feeb66d9f148ad1b6ff0f7a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 16 Oct 2025 09:31:18 +0800 Subject: [PATCH 02/11] Update vllm/model_executor/layers/pooler.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 920d50dfbd3c..3211d5b9662c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -416,7 +416,7 @@ def forward( class DummyPooler(Pooler): def get_supported_tasks(self) -> Set[PoolingTask]: - raise {"plugin", "score"} + return {"plugin", "score"} def forward( self, From 75ad7dd343333b67d660b1207105269733d2a803 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 16 Oct 2025 09:36:54 +0800 Subject: [PATCH 03/11] add plugin pooling task Signed-off-by: wang.yuqi --- .../prithvi_geospatial_mae_io_processor.py | 6 +----- tests/plugins_tests/test_io_processor_plugins.py | 5 +---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 8abe524d4954..92144773ba14 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -6,7 +6,6 @@ import torch from vllm import LLM -from vllm.pooling_params import PoolingParams # This example shows how to perform an offline inference that generates # multimodal data. 
In this specific case this example will take a geotiff @@ -40,10 +39,7 @@ def main(): model_impl="terratorch", ) - pooler_output = llm.encode( - img_prompt, - pooling_task="plugin" - ) + pooler_output = llm.encode(img_prompt, pooling_task="plugin") output = pooler_output[0].outputs print(output) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 76c65d3babb7..49a3eda6da39 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.plugins.io_processors import get_io_processor -from vllm.pooling_params import PoolingParams MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" @@ -105,9 +104,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): model_impl="terratorch", io_processor_plugin="prithvi_to_tiff", ) as llm_runner: - pooler_output = llm_runner.get_llm().encode( - img_prompt, pooling_task="plugin" - ) + pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin") output = pooler_output[0].outputs # verify the output is formatted as expected for this plugin From 970fabd3ee321abc72cc6a4bae8839c3d7afe943 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Wed, 22 Oct 2025 18:49:47 +0800 Subject: [PATCH 04/11] fix Signed-off-by: wang.yuqi --- vllm/entrypoints/openai/serving_pooling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index b0f3acadd9e2..140ce1896ceb 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -9,6 +9,7 @@ import jinja2 from fastapi import Request +from typing_extensions import assert_never from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption From db63ee3e6bbc764a5aea9e47e4797d92e4df5f5c Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 22 Oct 2025 10:21:42 +0000 Subject: [PATCH 05/11] Added validation/generation of Pooling/Sampling parameters with IOProcessor plugin Signed-off-by: Christian Pinto --- vllm/entrypoints/llm.py | 31 +++++++++++++--------- vllm/entrypoints/openai/api_server.py | 7 ++++- vllm/entrypoints/openai/protocol.py | 3 --- vllm/entrypoints/openai/serving_pooling.py | 5 +++- vllm/plugins/io_processors/interface.py | 9 +++++++ 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e82db693c92d..8d4f44bac27e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1014,19 +1014,6 @@ def encode( "pooling model." ) - if pooling_task not in self.supported_tasks: - raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") - - if pooling_params is None: - # Use default pooling params. 
- pooling_params = PoolingParams() - - for param in as_iter(pooling_params): - param.verify(pooling_task, model_config) - # for backwards compatibility - if truncate_prompt_tokens is not None: - param.truncate_prompt_tokens = truncate_prompt_tokens - io_processor_prompt = False if isinstance(prompts, dict) and "data" in prompts: io_processor_prompt = True @@ -1044,6 +1031,24 @@ def encode( # obtain the actual model prompts from the pre-processor prompts = self.io_processor.pre_process(prompt=validated_prompt) + if io_processor_prompt: + pooling_params = self.io_processor.validate_or_generate_params( + pooling_params + ) + else: + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + if pooling_task not in self.supported_tasks: + raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") + + for param in as_iter(pooling_params): + param.verify(pooling_task, model_config) + # for backwards compatibility + if truncate_prompt_tokens is not None: + param.truncate_prompt_tokens = truncate_prompt_tokens + self._validate_and_add_requests( prompts=prompts, params=pooling_params, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f84a530fd004..cc257f5ea6cf 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1787,7 +1787,12 @@ async def init_app_state( log_error_stack=args.log_error_stack, ) ) - if ("token_embed" in supported_tasks or "token_classify" in supported_tasks) + if ( + any( + task in supported_tasks + for task in ["token_embed", "token_classify", "plugin"] + ) + ) else None ) state.openai_serving_embedding = ( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ca70faf62d62..147f897fb779 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1726,9 +1726,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ), ) - def to_pooling_params(self): - return PoolingParams() - class IOProcessorResponse(OpenAIBaseModel, Generic[T]): request_id: str | None = None diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 140ce1896ceb..1e0832d30379 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -161,7 +161,10 @@ async def create_pooling( # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: - pooling_params = request.to_pooling_params() + if is_io_processor_request: + pooling_params = self.io_processor.validate_or_generate_params() + else: + pooling_params = request.to_pooling_params() if "token_embed" in self.supported_tasks: pooling_task = "token_embed" elif "token_classify" in self.supported_tasks: diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index 81e077d5bdac..f66b2c4347d2 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -9,6 +9,8 @@ from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.inputs.data import PromptType from vllm.outputs import PoolingRequestOutput +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams IOProcessorInput = TypeVar("IOProcessorInput") IOProcessorOutput = TypeVar("IOProcessorOutput") @@ -63,6 +65,13 @@ async def post_process_async( def parse_request(self, request: Any) -> IOProcessorInput: raise NotImplementedError + def validate_or_generate_params( + self, params: SamplingParams | PoolingParams | None = None + ) -> SamplingParams | PoolingParams: + if params: + return params + return PoolingParams() + @abstractmethod def output_to_response( self, plugin_output: IOProcessorOutput From fe5ac87ca86304085051f337f70b9d18f5ba0c78 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 22 Oct 2025 12:55:54 +0000 Subject: [PATCH 06/11] Dummy pooler returns one less dimension. Fixed prithvi plugin accordingly Signed-off-by: Christian Pinto --- .../prithvi_io_processor/prithvi_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 772824cdde8f..5614f19d1a4f 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -368,9 +368,9 @@ def post_process( out_format = "b64_json" for output in model_output: - y_hat = output.outputs.data.argmax(dim=1) + y_hat = output.outputs.data.argmax(dim=0) pred = torch.nn.functional.interpolate( - y_hat.unsqueeze(1).float(), + y_hat[None, None, ...].float(), size=self.img_size, mode="nearest", ) From c38faaed06f3b3d6ba140c1632fddf09b3f557a7 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 22 Oct 2025 14:37:45 +0000 Subject: [PATCH 07/11] Solving pre-commit failures Signed-off-by: Christian Pinto --- vllm/entrypoints/llm.py | 16 +++++++++++++--- vllm/entrypoints/openai/serving_pooling.py | 3 +++ vllm/pooling_params.py | 3 ++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8d4f44bac27e..8078dcd7fe8b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1016,6 +1016,7 @@ def encode( io_processor_prompt = False if isinstance(prompts, dict) and "data" in prompts: + assert self.io_processor is not None io_processor_prompt = True if self.io_processor is None: raise ValueError( @@ -1032,9 +1033,18 @@ def encode( prompts = self.io_processor.pre_process(prompt=validated_prompt) if io_processor_prompt: - pooling_params = self.io_processor.validate_or_generate_params( - pooling_params - ) + assert self.io_processor is not None + if is_list_of(pooling_params, PoolingParams): + validated_pooling_params: list[PoolingParams] =
[] + for param in as_iter(pooling_params): + validated_pooling_params.append( + self.io_processor.validate_or_generate_params(param) + ) + pooling_params = validated_pooling_params + else: + pooling_params = self.io_processor.validate_or_generate_params( + pooling_params + ) else: if pooling_params is None: # Use default pooling params. diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 1e0832d30379..7d33aead4daa 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -162,6 +162,9 @@ async def create_pooling( generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: if is_io_processor_request: + assert self.io_processor is not None and isinstance( + request, IOProcessorRequest + ) pooling_params = self.io_processor.validate_or_generate_params() else: pooling_params = request.to_pooling_params() diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 090d92414465..b46f41141091 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -7,7 +7,7 @@ import msgspec from vllm.sampling_params import RequestOutputKind -from vllm.tasks import PoolingTask +from vllm.tasks import POOLING_TASKS, PoolingTask if TYPE_CHECKING: from vllm.config import ModelConfig, PoolerConfig @@ -78,6 +78,7 @@ def clone(self) -> "PoolingParams": def verify( self, task: PoolingTask, model_config: Optional["ModelConfig"] = None ) -> None: + assert task in POOLING_TASKS if self.task is None: self.task = task elif self.task != task: From c92753dcac82290e274627c1b7dd91e1aca58c27 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 22 Oct 2025 14:43:43 +0000 Subject: [PATCH 08/11] Solving pre-commit failures Signed-off-by: Christian Pinto --- vllm/entrypoints/llm.py | 1 + vllm/entrypoints/openai/protocol.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8078dcd7fe8b..3c7112c8c2a3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1042,6 +1042,7 @@ def encode( ) pooling_params = validated_pooling_params else: + assert not is_list_of(pooling_params, PoolingParams) pooling_params = self.io_processor.validate_or_generate_params( pooling_params ) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 147f897fb779..ca70faf62d62 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1726,6 +1726,9 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ), ) + def to_pooling_params(self): + return PoolingParams() + class IOProcessorResponse(OpenAIBaseModel, Generic[T]): request_id: str | None = None From 5f0f625096087f33e839b87fb0bccdcc7aeeac07 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 22 Oct 2025 15:30:25 +0000 Subject: [PATCH 09/11] Still trying to please the type checker Signed-off-by: Christian Pinto --- vllm/entrypoints/llm.py | 2 +- vllm/entrypoints/openai/serving_pooling.py | 3 ++- vllm/pooling_params.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3c7112c8c2a3..6fedea235ab2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1042,7 +1042,7 @@ def encode( ) pooling_params = validated_pooling_params else: - assert not is_list_of(pooling_params, PoolingParams) + assert not isinstance(pooling_params, Sequence) pooling_params = self.io_processor.validate_or_generate_params( pooling_params ) diff --git 
a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 7d33aead4daa..568896ccbf1b 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -32,7 +32,7 @@ from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput -from vllm.tasks import SupportedTask +from vllm.tasks import PoolingTask, SupportedTask from vllm.utils.async_utils import merge_async_iterators from vllm.utils.serial_utils import ( EmbedDType, @@ -169,6 +169,7 @@ async def create_pooling( else: pooling_params = request.to_pooling_params() + pooling_task: PoolingTask if "token_embed" in self.supported_tasks: pooling_task = "token_embed" elif "token_classify" in self.supported_tasks: diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index b46f41141091..090d92414465 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -7,7 +7,7 @@ import msgspec from vllm.sampling_params import RequestOutputKind -from vllm.tasks import POOLING_TASKS, PoolingTask +from vllm.tasks import PoolingTask if TYPE_CHECKING: from vllm.config import ModelConfig, PoolerConfig @@ -78,7 +78,6 @@ def clone(self) -> "PoolingParams": def verify( self, task: PoolingTask, model_config: Optional["ModelConfig"] = None ) -> None: - assert task in POOLING_TASKS if self.task is None: self.task = task elif self.task != task: From 71564fa7e5a03ebd319f1ab470a9c4cd8087c2c0 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 23 Oct 2025 07:29:37 +0000 Subject: [PATCH 10/11] Fixed documentation and updated examples requirements Signed-off-by: Christian Pinto --- docs/design/io_processor_plugins.md | 20 +++++++++++++++---- .../prithvi_geospatial_mae_io_processor.py | 7 ++++--- .../online_serving/prithvi_geospatial_mae.py | 6 +++--- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index 1873566d0981..170b7750cacc 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput") IOProcessorOutput = TypeVar("IOProcessorOutput") class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): - def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config @@ -49,13 +48,26 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: - collected_output = [item async for i, item in model_output] + # We cannot guarantee outputs are returned in the same order they were + # fed to vLLM. 
+# Let's sort them by id before post_processing + sorted_output = sorted( + [(i, item) async for i, item in model_output], key=lambda output: output[0] + ) + collected_output = [output[1] for output in sorted_output] return self.post_process(collected_output, request_id, **kwargs) @abstractmethod def parse_request(self, request: Any) -> IOProcessorInput: raise NotImplementedError + def validate_or_generate_params( + self, params: SamplingParams | PoolingParams | None = None + ) -> SamplingParams | PoolingParams: + if params: + return params + return PoolingParams() + @abstractmethod def output_to_response( self, plugin_output: IOProcessorOutput @@ -66,10 +78,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. - +The `validate_or_generate_params` method lets the plugin validate any `SamplingParams`/`PoolingParams` received with the user request, or generate new ones if none are specified. The function always returns the validated/generated parameters. The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py). -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples. ## Using an IO Processor plugin diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 6879634aa8b1..b8637b89e08f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -11,8 +11,9 @@ # multimodal data. In this specific case this example will take a geotiff # image as input, process it using the multimodal data processor, and # perform inference.
-# Requirement - install plugin at: -# https://github.com/christian-pinto/prithvi_io_processor_plugin +# Requirements: +# - install TerraTorch v1.1 (or later): +# pip install terratorch>=v1.1 def main(): @@ -35,7 +36,7 @@ def main(): # to avoid the model going OOM. # The maximum number depends on the available GPU memory max_num_seqs=32, - io_processor_plugin="prithvi_to_tiff", + io_processor_plugin="terratorch_segmentation", model_impl="terratorch", enable_mm_embeds=True, ) diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index 64556b846728..a6246999c14d 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -11,14 +11,14 @@ # image as input, process it using the multimodal data processor, and # perform inference. # Requirements : -# - install plugin at: -# https://github.com/christian-pinto/prithvi_io_processor_plugin +# - install TerraTorch v1.1 (or later): +# pip install terratorch>=v1.1 # - start vllm in serving mode with the below args # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model-impl terratorch # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager -# --io-processor-plugin prithvi_to_tiff +# --io-processor-plugin terratorch_segmentation # --enable-mm-embeds From fd885c216ee2f9a57e00fbd3b3a774e2bf44c28a Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 23 Oct 2025 12:18:08 +0000 Subject: [PATCH 11/11] Small fixes after review Signed-off-by: Christian Pinto --- docs/design/io_processor_plugins.md | 4 +--- vllm/entrypoints/llm.py | 1 - vllm/plugins/io_processors/interface.py | 4 +--- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index 170b7750cacc..fb64a7bb9c8f 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -64,9 +64,7 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): def validate_or_generate_params( self, params: SamplingParams | PoolingParams | None = None ) -> SamplingParams | PoolingParams: - if params: - return params - return PoolingParams() + return params or PoolingParams() @abstractmethod def output_to_response( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5f5cb4466d21..290acf4afb52 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1026,7 +1026,6 @@ def encode( io_processor_prompt = False if isinstance(prompts, dict) and "data" in prompts: - assert self.io_processor is not None io_processor_prompt = True if self.io_processor is None: raise ValueError( diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index f66b2c4347d2..e0488e48614d 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -68,9 +68,7 @@ def parse_request(self, request: Any) -> IOProcessorInput: def validate_or_generate_params( self, params: SamplingParams | PoolingParams | None = None ) -> SamplingParams | PoolingParams: - if params: - return params - return PoolingParams() + return params or PoolingParams() @abstractmethod def output_to_response(
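With the final patch the base implementation collapses to `return params or PoolingParams()`, so a concrete plugin only needs to override the hook when it wants stricter policing of user-supplied parameters. A hedged sketch of such an override, built on the interface as it stands at the end of the series; the class name, the `SamplingParams` rejection, and the `task="plugin"` default are illustrative assumptions, not behavior mandated by the series:

```python
# Hedged sketch: a plugin specializing the validate_or_generate_params()
# hook added in this series. The rejection rule and pinned task below are
# assumptions about plugin policy, not vLLM requirements.
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


class SegmentationParamsMixin:
    def validate_or_generate_params(
        self, params: SamplingParams | PoolingParams | None = None
    ) -> SamplingParams | PoolingParams:
        # A pooling-only plugin has no use for sampling parameters.
        if isinstance(params, SamplingParams):
            raise ValueError("this plugin only accepts pooling requests")
        # Pin the task so PoolingParams.verify() takes the early return
        # added for "plugin" in pooling_params.py.
        return params or PoolingParams(task="plugin")
```

Because `LLM.encode` now routes every IO-processor request through this hook (per item when a list of `PoolingParams` is passed), the plugin becomes the single owner of parameter validation for both the offline and online paths.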