diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 2695da5778aa..8c099b9531c5 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
 
-from vllm.inputs import PromptType
+from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -400,6 +400,7 @@ def validate_request(
         cls,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d8807a72ba2f..83dd3e9c817a 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -5,7 +5,7 @@
 import torch
 
 import vllm.envs as envs
-from vllm.inputs import PromptType
+from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
 
@@ -150,6 +150,7 @@ def validate_request(
         cls,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
         if isinstance(params, SamplingParams):
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 6d3290f16565..d5918f8a4bd8 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -202,12 +202,6 @@ def process_inputs(
 
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
-
-        from vllm.platforms import current_platform
-        current_platform.validate_request(
-            prompt=prompt,
-            params=params,
-        )
         self._validate_lora(lora_request)
         self._validate_params(params)
         if priority != 0:
@@ -231,6 +225,12 @@
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,
         )
+        from vllm.platforms import current_platform
+        current_platform.validate_request(
+            prompt=prompt,
+            params=params,
+            processed_inputs=processed_inputs,
+        )
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
         self._validate_model_inputs(processed_inputs, lora_request)
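
The change moves `current_platform.validate_request` to after input preprocessing, so platforms receive the processed inputs (including the tokenized prompt) in addition to the raw prompt. As a rough illustration of what the new `processed_inputs` parameter enables, here is a minimal sketch of a hypothetical platform subclass. `MyPlatform`, `MAX_PROMPT_LEN`, and the dict-style access to `prompt_token_ids` are illustrative assumptions, not part of this diff.

```python
# Hypothetical sketch only: MyPlatform and MAX_PROMPT_LEN are made up,
# and the dict-style access assumes decoder-only TypedDict inputs.
from typing import Union

from vllm.inputs import ProcessorInputs, PromptType
from vllm.platforms.interface import Platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

MAX_PROMPT_LEN = 8192  # made-up platform-specific limit


class MyPlatform(Platform):

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""
        # Because validation now runs after the input preprocessor,
        # the tokenized prompt is available here, not just the raw
        # prompt string or multimodal payload.
        token_ids = processed_inputs.get("prompt_token_ids")
        if token_ids is not None and len(token_ids) > MAX_PROMPT_LEN:
            raise ValueError(
                f"Prompt of length {len(token_ids)} exceeds this "
                f"platform's limit of {MAX_PROMPT_LEN} tokens.")
```

This kind of token-level check is not possible at the old call site, where only the raw `prompt` was available before tokenization.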