18 changes: 14 additions & 4 deletions docs/design/io_processor_plugins.md
@@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")

class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config

@@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
# We cannot guarantee outputs are returned in the same order they were
# fed to vLLM.
# Let's sort them by id before post_processing
sorted_output = sorted(
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)

@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError

def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()

@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
@@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.

The `validate_or_generate_params` method lets the plugin validate any `SamplingParams`/`PoolingParams` received with the user request, or generate new ones if none are specified. It always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available in [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
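
Putting these pieces together, a plugin skeleton might look roughly like the sketch below. The `MyImagePlugin` name, the `dict` input/output types, and the method bodies are hypothetical placeholders, and the signatures are simplified; refer to [vllm/plugins/io_processors/interface.py](../../vllm/plugins/io_processors/interface.py) for the exact interface.

```python
from typing import Any

from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams


class MyImagePlugin(IOProcessor[dict, dict]):
    """Hypothetical plugin whose input and output are plain dicts."""

    def parse_request(self, request: Any) -> dict:
        # Validate the raw user request and convert it into the plugin input,
        # e.g. pull the plugin-specific fields out of the request's "data" payload.
        return {"image_url": request["data"]["image_url"]}

    def pre_process(self, prompt: dict, request_id: str | None = None, **kwargs):
        # Build the vLLM model prompt(s) from the validated plugin input.
        ...

    def post_process(
        self,
        model_output: list[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> dict:
        # Convert the raw pooling outputs into the plugin's output format.
        ...

    def validate_or_generate_params(self, params=None):
        # Accept user-supplied params, or fall back to the defaults.
        return params or PoolingParams()

    def output_to_response(self, plugin_output: dict):
        # Wrap the plugin output into an IOProcessorResponse (online serving only).
        ...
```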

An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.

## Using an IO Processor plugin

2 changes: 1 addition & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -64,7 +64,7 @@ def run(self, input_data, location_coords):
}

prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False)
outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)

return outputs[0].outputs.data

@@ -6,14 +6,14 @@
import torch

from vllm import LLM
from vllm.pooling_params import PoolingParams

# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirement - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# Requirements:
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1


def main():
Expand All @@ -36,16 +36,12 @@ def main():
# to avoid the model going OOM.
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff",
io_processor_plugin="terratorch_segmentation",
model_impl="terratorch",
enable_mm_embeds=True,
)

pooling_params = PoolingParams(task="token_classify", activation=False)
pooler_output = llm.encode(
img_prompt,
pooling_params=pooling_params,
)
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs

print(output)
7 changes: 3 additions & 4 deletions examples/online_serving/prithvi_geospatial_mae.py
@@ -11,14 +11,14 @@
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirements :
# - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff
# --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds


@@ -35,7 +35,6 @@ def main():
},
"priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
"softmax": False,
}

ret = requests.post(server_endpoint, json=request_payload_url)
2 changes: 1 addition & 1 deletion tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -40,7 +40,7 @@ def _run_test(
max_num_seqs=32,
default_torch_num_threads=1,
) as vllm_model:
vllm_model.llm.encode(prompt, pooling_task="token_classify")
vllm_model.llm.encode(prompt, pooling_task="plugin")


MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@@ -368,9 +368,9 @@ def post_process(
out_format = "b64_json"

for output in model_output:
y_hat = output.outputs.data.argmax(dim=1)
y_hat = output.outputs.data.argmax(dim=0)
pred = torch.nn.functional.interpolate(
y_hat.unsqueeze(1).float(),
y_hat[None, None, ...].float(),
size=self.img_size,
mode="nearest",
)
7 changes: 1 addition & 6 deletions tests/plugins_tests/test_io_processor_plugins.py
@@ -9,7 +9,6 @@
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams

MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

@@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
out_data_format="b64_json",
)

pooling_params = PoolingParams(activation=False)

with vllm_runner(
model_name,
runner="pooling",
@@ -109,9 +106,7 @@ def main():
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
)
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs

# verify the output is formatted as expected for this plugin
41 changes: 28 additions & 13 deletions vllm/entrypoints/llm.py
@@ -1024,19 +1024,6 @@ def encode(
"pooling model."
)

if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()

for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens

io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
@@ -1054,6 +1041,34 @@
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)

if io_processor_prompt:
assert self.io_processor is not None
if is_list_of(pooling_params, PoolingParams):
validated_pooling_params: list[PoolingParams] = []
for param in as_iter(pooling_params):
validated_pooling_params.append(
self.io_processor.validate_or_generate_params(param)
)
pooling_params = validated_pooling_params
else:
assert not isinstance(pooling_params, Sequence)
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
)
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()

if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens

self._validate_and_add_requests(
prompts=prompts,
params=pooling_params,
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -1748,7 +1748,12 @@ async def init_app_state(
log_error_stack=args.log_error_stack,
)
)
if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
if (
any(
task in supported_tasks
for task in ["token_embed", "token_classify", "plugin"]
)
)
else None
)
state.openai_serving_embedding = (
7 changes: 1 addition & 6 deletions vllm/entrypoints/openai/protocol.py
@@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
if the served model does not use priority scheduling.
"""
data: T
"""
When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
"""
activation: bool = False

encoding_format: EncodingFormat = "float"
embed_dtype: EmbedDType = Field(
@@ -1732,7 +1727,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
)

def to_pooling_params(self):
return PoolingParams(task="token_classify", activation=self.activation)
return PoolingParams()


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
13 changes: 11 additions & 2 deletions vllm/entrypoints/openai/serving_pooling.py
@@ -32,7 +32,7 @@
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import (
EmbedDType,
@@ -161,12 +161,21 @@ async def create_pooling(
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
if is_io_processor_request:
assert self.io_processor is not None and isinstance(
request, IOProcessorRequest
)
pooling_params = self.io_processor.validate_or_generate_params()
else:
pooling_params = request.to_pooling_params()

pooling_task: PoolingTask
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
elif "plugin" in self.supported_tasks:
pooling_task = "plugin"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/pooler.py
@@ -414,6 +414,18 @@ def forward(
raise NotImplementedError


class DummyPooler(Pooler):
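"""Pass-through pooler for IO processor plugins: the hidden states are returned unchanged so the plugin's post-processing step can work on the raw model output."""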
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"plugin", "score"}

def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return hidden_states


class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
6 changes: 2 additions & 4 deletions vllm/model_executor/models/terratorch.py
@@ -34,7 +34,7 @@
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -249,9 +249,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = DispatchPooler({"plugin": DummyPooler()})

def get_input_embeddings(
self,
7 changes: 7 additions & 0 deletions vllm/plugins/io_processors/interface.py
@@ -9,6 +9,8 @@
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
@@ -63,6 +65,11 @@ async def post_process_async(
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError

def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()

@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
5 changes: 5 additions & 0 deletions vllm/pooling_params.py
@@ -84,6 +84,11 @@ def verify(
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)

# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
return

# NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included
# in this method
4 changes: 3 additions & 1 deletion vllm/tasks.py
@@ -5,7 +5,9 @@
GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)

PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
]
POOLING_TASKS = get_args(PoolingTask)

SupportedTask = Literal[GenerationTask, PoolingTask]