diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py
index 35d001781110..aad7fc5303c1 100644
--- a/tests/mq_llm_engine/test_error_handling.py
+++ b/tests/mq_llm_engine/test_error_handling.py
@@ -18,6 +18,7 @@
 from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.lora.request import LoRARequest
+from vllm.sequence import SequenceGroupMetadata
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
         await client.check_health()
 
         client.close()
+
+
+def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
+                                   ipc_path: str):
+    """Simulate an exception while preparing inputs for the model.
+    In the wild, this could be something like a multimodal input processor
+    failing on invalid image data."""
+
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    runner = engine.engine.model_executor.driver_worker.worker.model_runner
+
+    # Raise error in the model runner when adding a sequence group.
+    # See class ModelInputForGPUBuilder
+    def raiser(_, seq_group_metadata: SequenceGroupMetadata):
+        if seq_group_metadata.request_id.startswith("evil"):
+            raise RAISED_ERROR(RAISED_VALUE)
+
+    runner.builder.per_seq_group_compute_fns.append(raiser)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_inputs(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_input_processing) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # Engine should be healthy
+        await client.check_health()
+
+        async def run_failing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id="evil" + str(uuid.uuid4())):
+                pass
+
+        async def run_passing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id=str(uuid.uuid4())):
+                pass
+
+        passing_tasks = [
+            asyncio.create_task(run_passing_request()) for _ in range(10)
+        ]
+        failing_tasks = [
+            asyncio.create_task(run_failing_request()) for _ in range(10)
+        ]
+        await asyncio.gather(*failing_tasks, return_exceptions=True)
+        await asyncio.gather(*passing_tasks)
+
+        # All the bad inputs should have raised
+        for task in failing_tasks:
+            with pytest.raises(RAISED_ERROR):
+                task.result()
+
+        # But all good inputs should have still succeeded
+        for task in passing_tasks:
+            task.result()
+
+        # And the engine should remain healthy
+        assert not client.errored
+        await client.check_health()
+
+        client.close()
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3ce9a0461368..ef2cd68bc5eb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -60,6 +60,7 @@
 from vllm.utils import (Counter, Device, deprecate_kwargs,
                         resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -410,6 +411,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
 
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
 
+        # Flag to set when an input fails to process and the engine should run
+        # the next step without re-scheduling.
+        self._skip_scheduling_next_step = False
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -1334,7 +1339,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
-        if not self._has_remaining_steps(seq_group_metadata_list):
+        # The scheduler is also skipped if a single request caused the last
+        # engine step to fail, and the previous schedule needs to be rerun.
+        if not self._has_remaining_steps(
+                seq_group_metadata_list
+        ) and not self._skip_scheduling_next_step:
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
@@ -1388,8 +1397,23 @@
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
 
-            outputs = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+            try:
+                outputs = self.model_executor.execute_model(
+                    execute_model_req=execute_model_req)
+                self._skip_scheduling_next_step = False
+            except InputProcessingError as e:
+                # The input for this request cannot be processed, so we must
+                # abort it. If there are remaining requests in the batch that
+                # have been scheduled, they will be retried on the next step.
+                invalid_request_id = e.request_id
+                self._abort_and_cache_schedule(
+                    request_id=invalid_request_id,
+                    virtual_engine=virtual_engine,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    allow_async_output_proc=allow_async_output_proc)
+                # Raise so the caller is notified that this request failed
+                raise
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
@@ -1464,6 +1488,38 @@
 
         return ctx.request_outputs
 
+    def _abort_and_cache_schedule(
+            self, request_id: str, virtual_engine: int,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            scheduler_outputs: SchedulerOutputs,
+            allow_async_output_proc: bool) -> None:
+        """Aborts a single request, and caches the scheduler outputs minus that
+        request. This allows the next step to continue processing the remaining
+        requests without having to re-run the scheduler."""
+
+        # Abort the request and remove its sequence group from the current
+        # schedule
+        self.abort_request(request_id)
+        for i, metadata in enumerate(seq_group_metadata_list):
+            if metadata.request_id == request_id:
+                del seq_group_metadata_list[i]
+                break
+        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
+            if group.seq_group.request_id == request_id:
+                del scheduler_outputs.scheduled_seq_groups[i]
+                break
+
+        # If there are still other sequence groups left in the schedule, cache
+        # them and flag the engine to reuse the schedule.
+        if len(seq_group_metadata_list) > 0:
+            self._skip_scheduling_next_step = True
+            # Reuse multi-step caching logic
+            self._cache_scheduler_outputs_for_multi_step(
+                virtual_engine=virtual_engine,
+                scheduler_outputs=scheduler_outputs,
+                seq_group_metadata_list=seq_group_metadata_list,
+                allow_async_output_proc=allow_async_output_proc)
+
     def _has_remaining_steps(
             self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
     ) -> bool:
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index ce24aa21514d..efea6ee2c69a 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -27,6 +27,7 @@
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 
@@ -210,6 +211,14 @@ def engine_step(self) -> List[RequestOutput]:
             return self.engine.step()
         except SystemExit:
             raise
+        except InputProcessingError as e:
+            # Special case where we handle an error preparing the inputs for
+            # a single request in the batch
+            rpc_err = RPCError(request_id=e.request_id,
+                               is_engine_errored=False,
+                               exception=e.__cause__)
+            self._send_outputs(rpc_err)
+            return []
         except BaseException as e:
             self._set_errored(e)
             rpc_err = RPCError(request_id=None,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 86dcde234f86..a37a3168bbbc 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -53,8 +53,8 @@
                         is_pin_memory_available, supports_dynamo,
                         weak_ref_tensor)
 from vllm.worker.model_runner_base import (
-    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
-    _add_attn_metadata_broadcastable_dict,
+    InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
+    ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
@@ -1216,7 +1216,12 @@ def _prepare_model_input_tensors(
         """
         self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            self.builder.add_seq_group(seq_group_metadata)
+            try:
+                self.builder.add_seq_group(seq_group_metadata)
+            except Exception as e:
+                # Raise an exception that tracks the ID of the bad request
+                raise InputProcessingError(seq_group_metadata.request_id,
+                                           str(e)) from e
 
         self.builder.reset_cached_inter_data()
 
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index bae37cb7155f..935325cb2e1c 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -261,3 +261,21 @@ def __init__(
 
     def __getattr__(self, attr):
         return getattr(self.model_runner, attr)
+
+
+class InputProcessingError(Exception):
+    """This exception is raised when an error occurs preparing the inputs for
+    a single sequence group.
+    This allows the engine to gracefully handle errors with a single sequence
+    group without having to fail the entire batch.
+    """
+
+    def __init__(self, request_id, message):
+        """request_id is the id of the offending sequence group"""
+        self.request_id = request_id
+        self.message = message
+        super().__init__(self.message)
+
+    def __str__(self):
+        return "Failed to prepare inputs for sequence group with request id: " \
+            f"{self.request_id}, Error: {self.message}"