diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md
index 1873566d0981..fb64a7bb9c8f 100644
--- a/docs/design/io_processor_plugins.md
+++ b/docs/design/io_processor_plugins.md
@@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")
 
 class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
-
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
 
@@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
         request_id: str | None = None,
         **kwargs,
     ) -> IOProcessorOutput:
-        collected_output = [item async for i, item in model_output]
+        # We cannot guarantee outputs are returned in the same order they were
+        # fed to vLLM.
+        # Let's sort them by id before post_processing
+        sorted_output = sorted(
+            [(i, item) async for i, item in model_output], key=lambda output: output[0]
+        )
+        collected_output = [output[1] for output in sorted_output]
         return self.post_process(collected_output, request_id, **kwargs)
 
     @abstractmethod
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError
 
+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
@@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
 The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
 The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
 The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
-
+The `validate_or_generate_params` method lets the plugin validate any `SamplingParams`/`PoolingParams` received with the user request, or generate new ones if none are specified. It always returns the validated or generated parameters.
 The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
 
-An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
+An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
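To make the interface described above concrete, here is a minimal sketch of a custom plugin. The class name `MyPlugin` and the dict-based input/output types are hypothetical, the return values are illustrative, and the signatures are abbreviated to what is shown in this document; a real plugin should match the `IOProcessor` base class exactly.

```python
from typing import Any

from vllm.config import VllmConfig
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams


class MyPlugin(IOProcessor[dict, dict]):
    """Illustrative plugin whose input and output are plain dicts."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def parse_request(self, request: Any) -> dict:
        # Validate the raw user request and return the plugin input.
        if not isinstance(request, dict) or "data" not in request:
            raise ValueError("request must be a dict with a 'data' field")
        return request

    def pre_process(self, prompt: dict, request_id: str | None = None, **kwargs):
        # Turn the validated plugin input into a vLLM model prompt.
        return {"prompt_token_ids": [1], "multi_modal_data": prompt["data"]}

    def post_process(
        self,
        model_output: list[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> dict:
        # Convert the raw pooling outputs into the plugin's output type.
        return {"num_outputs": len(model_output)}

    def validate_or_generate_params(self, params=None):
        # Accept caller-supplied params or fall back to defaults,
        # mirroring the base-class behaviour shown above.
        return params or PoolingParams()

    def output_to_response(self, plugin_output: dict):
        # Needed only for online serving: wrap the plugin output into an
        # IOProcessorResponse. Omitted in this sketch.
        raise NotImplementedError
```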
 ## Using an IO Processor plugin
diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py
index 8aa3fc1d3c85..b093c77c00b7 100644
--- a/examples/offline_inference/prithvi_geospatial_mae.py
+++ b/examples/offline_inference/prithvi_geospatial_mae.py
@@ -64,7 +64,7 @@ def run(self, input_data, location_coords):
         }
         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
 
-        outputs = self.model.encode(prompt, use_tqdm=False)
+        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)
 
         return outputs[0].outputs.data
 
diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
index afe8f056cc5f..b8637b89e08f 100644
--- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
+++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
@@ -6,14 +6,14 @@
 import torch
 
 from vllm import LLM
-from vllm.pooling_params import PoolingParams
 
 # This example shows how to perform an offline inference that generates
 # multimodal data. In this specific case this example will take a geotiff
 # image as input, process it using the multimodal data processor, and
 # perform inference.
-# Requirement - install plugin at:
-#   https://github.com/christian-pinto/prithvi_io_processor_plugin
+# Requirements:
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 
 
 def main():
@@ -36,16 +36,12 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff",
+        io_processor_plugin="terratorch_segmentation",
         model_impl="terratorch",
         enable_mm_embeds=True,
     )
 
-    pooling_params = PoolingParams(task="token_classify", activation=False)
-    pooler_output = llm.encode(
-        img_prompt,
-        pooling_params=pooling_params,
-    )
+    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
 
     output = pooler_output[0].outputs
     print(output)
diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py
index fba52fe77139..a6246999c14d 100644
--- a/examples/online_serving/prithvi_geospatial_mae.py
+++ b/examples/online_serving/prithvi_geospatial_mae.py
@@ -11,14 +11,14 @@
 # image as input, process it using the multimodal data processor, and
 # perform inference.
 # Requirements :
-# - install plugin at:
-#   https://github.com/christian-pinto/prithvi_io_processor_plugin
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 #   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
-#   --io-processor-plugin prithvi_to_tiff
+#   --io-processor-plugin terratorch_segmentation
 #   --enable-mm-embeds
@@ -35,7 +35,6 @@ def main():
         },
         "priority": 0,
         "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
-        "softmax": False,
     }
 
     ret = requests.post(server_endpoint, json=request_payload_url)
diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py
index 676076c45847..5082827962d8 100644
--- a/tests/models/multimodal/pooling/test_prithvi_mae.py
+++ b/tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -40,7 +40,7 @@ def _run_test(
         max_num_seqs=32,
         default_torch_num_threads=1,
     ) as vllm_model:
-        vllm_model.llm.encode(prompt, pooling_task="token_classify")
+        vllm_model.llm.encode(prompt, pooling_task="plugin")
 
 
 MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py
index 772824cdde8f..5614f19d1a4f 100644
--- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py
+++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py
@@ -368,9 +368,9 @@ def post_process(
             out_format = "b64_json"
 
         for output in model_output:
-            y_hat = output.outputs.data.argmax(dim=1)
+            y_hat = output.outputs.data.argmax(dim=0)
             pred = torch.nn.functional.interpolate(
-                y_hat.unsqueeze(1).float(),
+                y_hat[None, None, ...].float(),
                 size=self.img_size,
                 mode="nearest",
             )
diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py
index 1a024118fcfe..582cf9a0711b 100644
--- a/tests/plugins_tests/test_io_processor_plugins.py
+++ b/tests/plugins_tests/test_io_processor_plugins.py
@@ -9,7 +9,6 @@
 from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
-from vllm.pooling_params import PoolingParams
 
 MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 
@@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         out_data_format="b64_json",
     )
 
-    pooling_params = PoolingParams(activation=False)
-
     with vllm_runner(
         model_name,
         runner="pooling",
@@ -109,9 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         model_impl="terratorch",
         io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
-        )
+        pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
 
         output = pooler_output[0].outputs
 
    # verify the output is formatted as expected for this plugin
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 869861afff03..290acf4afb52 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1024,19 +1024,6 @@ def encode(
                 "pooling model."
             )
 
-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
-        if pooling_params is None:
-            # Use default pooling params.
-            pooling_params = PoolingParams()
-
-        for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
-            # for backwards compatibility
-            if truncate_prompt_tokens is not None:
-                param.truncate_prompt_tokens = truncate_prompt_tokens
-
         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
             io_processor_prompt = True
@@ -1054,6 +1041,34 @@ def encode(
             # obtain the actual model prompts from the pre-processor
             prompts = self.io_processor.pre_process(prompt=validated_prompt)
 
+        if io_processor_prompt:
+            assert self.io_processor is not None
+            if is_list_of(pooling_params, PoolingParams):
+                validated_pooling_params: list[PoolingParams] = []
+                for param in as_iter(pooling_params):
+                    validated_pooling_params.append(
+                        self.io_processor.validate_or_generate_params(param)
+                    )
+                pooling_params = validated_pooling_params
+            else:
+                assert not isinstance(pooling_params, Sequence)
+                pooling_params = self.io_processor.validate_or_generate_params(
+                    pooling_params
+                )
+        else:
+            if pooling_params is None:
+                # Use default pooling params.
+                pooling_params = PoolingParams()
+
+        if pooling_task not in self.supported_tasks:
+            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
+
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
+
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index c455d5016623..29306e45bcf0 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1748,7 +1748,12 @@ async def init_app_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
+        if (
+            any(
+                task in supported_tasks
+                for task in ["token_embed", "token_classify", "plugin"]
+            )
+        )
         else None
     )
     state.openai_serving_embedding = (
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 6d0149960f3c..ca70faf62d62 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     if the served model does not use priority scheduling.
     """
     data: T
-    """
-    When using plugins IOProcessor plugins, the actual input is processed
-    by the plugin itself. Hence, we use a generic type for the request data
-    """
-    activation: bool = False
 
     encoding_format: EncodingFormat = "float"
     embed_dtype: EmbedDType = Field(
@@ -1732,7 +1727,7 @@
     )
 
     def to_pooling_params(self):
-        return PoolingParams(task="token_classify", activation=self.activation)
+        return PoolingParams()
 
 
 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 5d4a638808b1..568896ccbf1b 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -32,7 +32,7 @@
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.tasks import SupportedTask
+from vllm.tasks import PoolingTask, SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import (
     EmbedDType,
@@ -161,12 +161,21 @@ async def create_pooling(
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
-            pooling_params = request.to_pooling_params()
+            if is_io_processor_request:
+                assert self.io_processor is not None and isinstance(
+                    request, IOProcessorRequest
+                )
+                pooling_params = self.io_processor.validate_or_generate_params()
+            else:
+                pooling_params = request.to_pooling_params()
 
+            pooling_task: PoolingTask
             if "token_embed" in self.supported_tasks:
                 pooling_task = "token_embed"
             elif "token_classify" in self.supported_tasks:
                 pooling_task = "token_classify"
+            elif "plugin" in self.supported_tasks:
+                pooling_task = "plugin"
             else:
                 return self.create_error_response(
                     f"pooling_task must be one of {self.supported_tasks}."
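For the serving path above, note that a client of the `/pooling` endpoint does not send pooling parameters when an IO Processor plugin is in use: the server obtains them from the plugin via `validate_or_generate_params()`. A minimal client sketch, assuming a local server started as in the online example earlier in this diff; the field names inside `data` are plugin-specific and only illustrative here.

```python
import requests

# Assumed local endpoint; adjust host/port to your deployment.
server_endpoint = "http://localhost:8000/pooling"

request_payload = {
    # The contents of "data" are interpreted by the IO Processor plugin
    # (here: a geotiff image reference for the terratorch_segmentation plugin).
    "data": {
        "data": "<url-or-path-to-geotiff>",
        "data_format": "url",
        "out_data_format": "b64_json",
    },
    "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    "priority": 0,
}

ret = requests.post(server_endpoint, json=request_payload)
ret.raise_for_status()
# The response body is an IOProcessorResponse produced by the plugin's
# output_to_response method.
print(ret.json())
```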
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index a8c66315684e..145f18f23566 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -414,6 +414,18 @@ def forward(
         raise NotImplementedError
 
 
+class DummyPooler(Pooler):
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"plugin", "score"}
+
+    def forward(
+        self,
+        hidden_states: list[torch.Tensor] | torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        return hidden_states
+
+
 class PoolerHead(nn.Module):
     def __init__(self, activation: PoolerActivation) -> None:
         super().__init__()
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index 0252705c62b1..e799e41e2c38 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -34,7 +34,7 @@
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
+from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -249,9 +249,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.pooler = DispatchPooler(
-            {"token_classify": Pooler.for_token_classify(pooler_config)}
-        )
+        self.pooler = DispatchPooler({"plugin": DummyPooler()})
 
     def get_input_embeddings(
         self,
diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py
index 81e077d5bdac..e0488e48614d 100644
--- a/vllm/plugins/io_processors/interface.py
+++ b/vllm/plugins/io_processors/interface.py
@@ -9,6 +9,8 @@
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.inputs.data import PromptType
 from vllm.outputs import PoolingRequestOutput
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import SamplingParams
 
 IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")
@@ -63,6 +65,11 @@ async def post_process_async(
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError
 
+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index c6dff6e01c1d..090d92414465 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -84,6 +84,11 @@ def verify(
             msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
             raise ValueError(msg)
 
+        # plugin task uses io_processor.parse_request to verify inputs,
+        # skipping PoolingParams verify
+        if self.task == "plugin":
+            return
+
         # NOTE: Task validation needs to done against the model instance,
         # which is not available in model config. So, it's not included
         # in this method
diff --git a/vllm/tasks.py b/vllm/tasks.py
index 6551444d1710..b02cde74c12a 100644
--- a/vllm/tasks.py
+++ b/vllm/tasks.py
@@ -5,7 +5,9 @@
 GenerationTask = Literal["generate", "transcription"]
 GENERATION_TASKS = get_args(GenerationTask)
 
-PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
+PoolingTask = Literal[
+    "embed", "classify", "score", "token_embed", "token_classify", "plugin"
+]
 POOLING_TASKS = get_args(PoolingTask)
 
 SupportedTask = Literal[GenerationTask, PoolingTask]
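Tying the pieces above together: the model registers a pass-through `DummyPooler` under the new `plugin` task, and callers select it with `pooling_task="plugin"` while the IO Processor plugin handles input parsing and output post-processing. A condensed offline sketch based on the examples earlier in this diff; the model name is taken from the online example, the geotiff reference is a placeholder, and additional engine arguments (e.g. `trust_remote_code`, `skip_tokenizer_init`) may be required as in the full example files.

```python
from vllm import LLM

# Prompt layout with a top-level "data" key, as expected by the IO Processor
# path in LLM.encode(); the inner fields are defined by the plugin and are
# illustrative here.
img_prompt = {
    "data": {
        "data": "<url-or-path-to-geotiff>",
        "data_format": "url",
        "out_data_format": "b64_json",
    },
}

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    model_impl="terratorch",
    io_processor_plugin="terratorch_segmentation",
    enable_mm_embeds=True,
    max_num_seqs=32,
)

# The "plugin" pooling task routes hidden states through DummyPooler unchanged;
# the plugin's post_process turns them into the final artifact.
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
print(pooler_output[0].outputs)
```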