Commit b8bbe59

noooop, gemini-code-assist[bot], and christian-pinto authored and committed
[Frontend][4/N] Improve all pooling task | Add plugin pooling task (vllm-project#26973)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 647b5ae commit b8bbe59

File tree

16 files changed: +102 −54 lines changed


docs/design/io_processor_plugins.md

Lines changed: 14 additions & 4 deletions

@@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")
 
 class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
-
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
 
@@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
         request_id: str | None = None,
         **kwargs,
     ) -> IOProcessorOutput:
-        collected_output = [item async for i, item in model_output]
+        # We cannot guarantee outputs are returned in the same order they were
+        # fed to vLLM.
+        # Let's sort them by id before post_processing
+        sorted_output = sorted(
+            [(i, item) async for i, item in model_output], key=lambda output: output[0]
+        )
+        collected_output = [output[1] for output in sorted_output]
         return self.post_process(collected_output, request_id, **kwargs)
 
     @abstractmethod
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError
 
+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
@@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
 The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
 The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
 The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
-
+The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
 The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
 
-An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
+An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
 
 ## Using an IO Processor plugin

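To make the documented interface concrete, here is a minimal sketch of a plugin subclass wired against the methods shown above, including the new validate_or_generate_params hook. It is illustrative only: the class and payload types are invented, the IOProcessor import path and the pre_process signature are assumptions based on this commit's docs and tests, and it is not the Prithvi/TerraTorch plugin itself.

# Hypothetical plugin sketch; only the method names and the default
# behaviour of validate_or_generate_params come from the docs above.
from typing import Any

from vllm.config import VllmConfig
from vllm.plugins.io_processors.interface import IOProcessor  # import path assumed
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


class MyGeoProcessor(IOProcessor[dict, list]):
    """Toy plugin: the input is a dict payload, the output is a list of tensors."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def parse_request(self, request: Any) -> dict:
        # Validate the raw user payload and return the plugin's input type.
        if not isinstance(request, dict) or "data" not in request:
            raise ValueError("request must be a dict with a 'data' field")
        return request

    def pre_process(self, prompt: dict, request_id: str | None = None, **kwargs):
        # Turn the plugin input into regular vLLM prompts (signature assumed).
        return [{"prompt_token_ids": [1], "multi_modal_data": prompt["data"]}]

    def post_process(self, model_output, request_id: str | None = None, **kwargs) -> list:
        # model_output is a sequence of PoolingRequestOutput objects.
        return [out.outputs.data for out in model_output]

    def validate_or_generate_params(
        self, params: SamplingParams | PoolingParams | None = None
    ) -> SamplingParams | PoolingParams:
        # Same default as the base class; a real plugin could also reject
        # params that make no sense for its task.
        return params or PoolingParams()

    def output_to_response(self, plugin_output: list):
        # Online serving only; omitted here because the IOProcessorResponse
        # constructor arguments are not shown in this commit.
        raise NotImplementedError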
examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def run(self, input_data, location_coords):
         }
 
         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-        outputs = self.model.encode(prompt, use_tqdm=False)
+        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)
 
         return outputs[0].outputs.data

examples/offline_inference/prithvi_geospatial_mae_io_processor.py

Lines changed: 5 additions & 9 deletions

@@ -6,14 +6,14 @@
 import torch
 
 from vllm import LLM
-from vllm.pooling_params import PoolingParams
 
 # This example shows how to perform an offline inference that generates
 # multimodal data. In this specific case this example will take a geotiff
 # image as input, process it using the multimodal data processor, and
 # perform inference.
-# Requirement - install plugin at:
-# https://github.com/christian-pinto/prithvi_io_processor_plugin
+# Requirements:
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 
 
 def main():
@@ -36,16 +36,12 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff",
+        io_processor_plugin="terratorch_segmentation",
         model_impl="terratorch",
         enable_mm_embeds=True,
     )
 
-    pooling_params = PoolingParams(task="token_classify", activation=False)
-    pooler_output = llm.encode(
-        img_prompt,
-        pooling_params=pooling_params,
-    )
+    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
     output = pooler_output[0].outputs
 
     print(output)

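Taken together with the llm.py change later in this commit, the offline flow collapses to a single encode(..., pooling_task="plugin") call. Below is a condensed sketch assembled from the example above and this commit's tests; the model name is the one used in the tests, and the img_prompt payload shape is a placeholder whose real structure is defined by the terratorch_segmentation plugin, not something verified here.

# Condensed offline sketch; requires terratorch>=1.1 for the
# "terratorch_segmentation" IO processor plugin. The img_prompt payload is a
# placeholder -- its real structure is defined by the plugin.
from vllm import LLM

img_prompt = {
    "data": {
        "data": "https://example.com/some_image.tif",  # placeholder input
        "data_format": "url",
        "out_data_format": "b64_json",
    },
}

llm = LLM(
    model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    runner="pooling",
    model_impl="terratorch",
    io_processor_plugin="terratorch_segmentation",
    enable_mm_embeds=True,
    max_num_seqs=32,
)

# pooling_task="plugin" routes the (optional) pooling params through the
# plugin's validate_or_generate_params(); no explicit PoolingParams needed.
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
print(pooler_output[0].outputs)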
examples/online_serving/prithvi_geospatial_mae.py

Lines changed: 3 additions & 4 deletions

@@ -11,14 +11,14 @@
 # image as input, process it using the multimodal data processor, and
 # perform inference.
 # Requirements :
-# - install plugin at:
-#   https://github.com/christian-pinto/prithvi_io_processor_plugin
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 #   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
-#   --io-processor-plugin prithvi_to_tiff
+#   --io-processor-plugin terratorch_segmentation
 #   --enable-mm-embeds
 
 
@@ -35,7 +35,6 @@ def main():
         },
         "priority": 0,
         "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
-        "softmax": False,
     }
 
     ret = requests.post(server_endpoint, json=request_payload_url)

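For online serving, the user-visible changes are the plugin name in the server flags and the dropped "softmax" field in the request body. A sketch of the resulting request follows; the host/port and the "data" payload structure are assumptions, while the endpoint path comes from the docs in this commit.

# Request sketch for the plugin-backed /pooling endpoint.
import requests

server_endpoint = "http://localhost:8000/pooling"  # host/port assumed

request_payload_url = {
    # Plugin-defined payload; the structure shown here is a placeholder.
    "data": {
        "data": "https://example.com/some_image.tif",
        "data_format": "url",
        "out_data_format": "b64_json",
    },
    "priority": 0,
    "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    # "softmax": False was removed from the payload in this commit.
}

ret = requests.post(server_endpoint, json=request_payload_url)
print(ret.status_code)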
tests/models/multimodal/pooling/test_prithvi_mae.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def _run_test(
         max_num_seqs=32,
         default_torch_num_threads=1,
     ) as vllm_model:
-        vllm_model.llm.encode(prompt, pooling_task="token_classify")
+        vllm_model.llm.encode(prompt, pooling_task="plugin")
 
 
 MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py

Lines changed: 2 additions & 2 deletions

@@ -368,9 +368,9 @@ def post_process(
         out_format = "b64_json"
 
         for output in model_output:
-            y_hat = output.outputs.data.argmax(dim=1)
+            y_hat = output.outputs.data.argmax(dim=0)
             pred = torch.nn.functional.interpolate(
-                y_hat.unsqueeze(1).float(),
+                y_hat[None, None, ...].float(),
                 size=self.img_size,
                 mode="nearest",
            )

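The plugin fix above reflects a change in how the per-request logits are reduced: the class dimension is now the leading one, so the argmax runs over dim=0 and the result is promoted to NCHW with [None, None, ...] before interpolation. A self-contained shape check, under an assumed (num_classes, H, W) logit layout (the real layout comes from the Prithvi plugin and may differ):

# Shape sanity check for the post_process fix above.
import torch

num_classes, h, w = 2, 16, 16
logits = torch.randn(num_classes, h, w)   # assumed (num_classes, H, W) layout

y_hat = logits.argmax(dim=0)              # (H, W) class indices
nchw = y_hat[None, None, ...].float()     # (1, 1, H, W), as interpolate expects
pred = torch.nn.functional.interpolate(nchw, size=(512, 512), mode="nearest")

print(y_hat.shape, nchw.shape, pred.shape)
# torch.Size([16, 16]) torch.Size([1, 1, 16, 16]) torch.Size([1, 1, 512, 512])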
tests/plugins_tests/test_io_processor_plugins.py

Lines changed: 1 addition & 6 deletions

@@ -9,7 +9,6 @@
 from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
-from vllm.pooling_params import PoolingParams
 
 MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 
@@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         out_data_format="b64_json",
     )
 
-    pooling_params = PoolingParams(activation=False)
-
     with vllm_runner(
         model_name,
         runner="pooling",
@@ -109,9 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         model_impl="terratorch",
         io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
-        )
+        pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
         output = pooler_output[0].outputs
 
     # verify the output is formatted as expected for this plugin

vllm/entrypoints/llm.py

Lines changed: 28 additions & 13 deletions

@@ -1024,19 +1024,6 @@ def encode(
                 "pooling model."
             )
 
-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
-        if pooling_params is None:
-            # Use default pooling params.
-            pooling_params = PoolingParams()
-
-        for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
-            # for backwards compatibility
-            if truncate_prompt_tokens is not None:
-                param.truncate_prompt_tokens = truncate_prompt_tokens
-
         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
             io_processor_prompt = True
@@ -1054,6 +1041,34 @@ def encode(
             # obtain the actual model prompts from the pre-processor
             prompts = self.io_processor.pre_process(prompt=validated_prompt)
 
+        if io_processor_prompt:
+            assert self.io_processor is not None
+            if is_list_of(pooling_params, PoolingParams):
+                validated_pooling_params: list[PoolingParams] = []
+                for param in as_iter(pooling_params):
+                    validated_pooling_params.append(
+                        self.io_processor.validate_or_generate_params(param)
+                    )
+                pooling_params = validated_pooling_params
+            else:
+                assert not isinstance(pooling_params, Sequence)
+                pooling_params = self.io_processor.validate_or_generate_params(
+                    pooling_params
+                )
+        else:
+            if pooling_params is None:
+                # Use default pooling params.
+                pooling_params = PoolingParams()
+
+            if pooling_task not in self.supported_tasks:
+                raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
+
+            for param in as_iter(pooling_params):
+                param.verify(pooling_task, model_config)
+                # for backwards compatibility
+                if truncate_prompt_tokens is not None:
+                    param.truncate_prompt_tokens = truncate_prompt_tokens
+
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,

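Read as a whole, the encode() change splits parameter handling into two paths: IO-processor prompts delegate to the plugin's validate_or_generate_params (so callers may omit pooling_params entirely), while regular prompts keep the existing default-and-verify behaviour. A stripped-down, self-contained sketch of that selection logic, using stand-in classes instead of vLLM's own types:

# Stand-in sketch of the branching added to LLM.encode(); not vLLM code.
from dataclasses import dataclass


@dataclass
class Params:  # stand-in for PoolingParams / SamplingParams
    task: str = "plugin"


class Plugin:  # stand-in for an IOProcessor plugin
    def validate_or_generate_params(self, params=None):
        return params or Params()


def resolve_params(pooling_params, io_processor_prompt, plugin, supported_tasks, task):
    if io_processor_prompt:
        # Plugin prompts: let the plugin validate or create the params.
        if isinstance(pooling_params, list):
            return [plugin.validate_or_generate_params(p) for p in pooling_params]
        return plugin.validate_or_generate_params(pooling_params)
    # Regular prompts: keep the old behaviour -- default params, verify the task.
    if task not in supported_tasks:
        raise ValueError(f"pooling_task must be one of {supported_tasks}.")
    return pooling_params if pooling_params is not None else Params(task=task)


print(resolve_params(None, True, Plugin(), {"embed"}, "plugin"))   # plugin fills params
print(resolve_params(None, False, Plugin(), {"embed"}, "embed"))   # default + verify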
vllm/entrypoints/openai/api_server.py

Lines changed: 6 additions & 1 deletion

@@ -1748,7 +1748,12 @@ async def init_app_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
+        if (
+            any(
+                task in supported_tasks
+                for task in ["token_embed", "token_classify", "plugin"]
+            )
+        )
         else None
     )
     state.openai_serving_embedding = (

vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 6 deletions

@@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    if the served model does not use priority scheduling.
    """
    data: T
-    """
-    When using plugins IOProcessor plugins, the actual input is processed
-    by the plugin itself. Hence, we use a generic type for the request data
-    """
-    activation: bool = False
 
    encoding_format: EncodingFormat = "float"
    embed_dtype: EmbedDType = Field(
@@ -1732,7 +1727,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    )
 
    def to_pooling_params(self):
-        return PoolingParams(task="token_classify", activation=self.activation)
+        return PoolingParams()
 
 
 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
