Commit 3f808cc

[Bugfix] Do not crash V0 engine on input errors (#13101)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
1 parent ec8a5e5 commit 3f808cc

File tree

5 files changed: +172 −6 lines changed

tests/mq_llm_engine/test_error_handling.py
vllm/engine/llm_engine.py
vllm/engine/multiprocessing/engine.py
vllm/worker/model_runner.py
vllm/worker/model_runner_base.py

tests/mq_llm_engine/test_error_handling.py

Lines changed: 78 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.entrypoints.openai.api_server import build_async_engine_client
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.lora.request import LoRARequest
+from vllm.sequence import SequenceGroupMetadata
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
             await client.check_health()
 
         client.close()
+
+
+def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
+                                   ipc_path: str):
+    """Simulate an exception while preparing inputs for the model.
+    In the wild, this could be something like a multimodal input processor
+    failing on invalid image data."""
+
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    runner = engine.engine.model_executor.driver_worker.worker.model_runner
+
+    # Raise error in the model runner when adding a sequence group.
+    # See class ModelInputForGPUBuilder
+    def raiser(_, seq_group_metadata: SequenceGroupMetadata):
+        if seq_group_metadata.request_id.startswith("evil"):
+            raise RAISED_ERROR(RAISED_VALUE)
+
+    runner.builder.per_seq_group_compute_fns.append(raiser)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_inputs(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_input_processing) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # Engine should be healthy
+        await client.check_health()
+
+        async def run_failing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id="evil" + str(uuid.uuid4())):
+                pass
+
+        async def run_passing_request():
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=10),
+                    request_id=str(uuid.uuid4())):
+                pass
+
+        passing_tasks = [
+            asyncio.create_task(run_passing_request()) for _ in range(10)
+        ]
+        failing_tasks = [
+            asyncio.create_task(run_failing_request()) for _ in range(10)
+        ]
+        await asyncio.gather(*failing_tasks, return_exceptions=True)
+        await asyncio.gather(*passing_tasks)
+
+        # All the bad inputs should have raised
+        for task in failing_tasks:
+            with pytest.raises(RAISED_ERROR):
+                task.result()
+
+        # But all good inputs should have still succeeded
+        for task in passing_tasks:
+            task.result()
+
+        # And the engine should remain healthy
+        assert not client.errored
+        await client.check_health()
+
+        client.close()

vllm/engine/llm_engine.py

Lines changed: 59 additions & 3 deletions
@@ -60,6 +60,7 @@
 from vllm.utils import (Counter, Device, deprecate_kwargs,
                         resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -410,6 +411,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
 
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
 
+        # Flag to set when an input fails to process and the engine should run
+        # the next step without re-scheduling.
+        self._skip_scheduling_next_step = False
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -1334,7 +1339,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
-        if not self._has_remaining_steps(seq_group_metadata_list):
+        # The scheduler is also skipped if a single request caused the last
+        # engine step to fail, and the previous schedule needs to be rerun.
+        if not self._has_remaining_steps(
+                seq_group_metadata_list
+        ) and not self._skip_scheduling_next_step:
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
@@ -1388,8 +1397,23 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
 
-            outputs = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+            try:
+                outputs = self.model_executor.execute_model(
+                    execute_model_req=execute_model_req)
+                self._skip_scheduling_next_step = False
+            except InputProcessingError as e:
+                # The input for this request cannot be processed, so we must
+                # abort it. If there are remaining requests in the batch that
+                # have been scheduled, they will be retried on the next step.
+                invalid_request_id = e.request_id
+                self._abort_and_cache_schedule(
+                    request_id=invalid_request_id,
+                    virtual_engine=virtual_engine,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    allow_async_output_proc=allow_async_output_proc)
+                # Raise so the caller is notified that this request failed
+                raise
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
@@ -1464,6 +1488,38 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
 
         return ctx.request_outputs
 
+    def _abort_and_cache_schedule(
+            self, request_id: str, virtual_engine: int,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            scheduler_outputs: SchedulerOutputs,
+            allow_async_output_proc: bool) -> None:
+        """Aborts a single request, and caches the scheduler outputs minus that
+        request. This allows the next step to continue processing the remaining
+        requests without having to re-run the scheduler."""
+
+        # Abort the request and remove its sequence group from the current
+        # schedule
+        self.abort_request(request_id)
+        for i, metadata in enumerate(seq_group_metadata_list):
+            if metadata.request_id == request_id:
+                del seq_group_metadata_list[i]
+                break
+        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
+            if group.seq_group.request_id == request_id:
+                del scheduler_outputs.scheduled_seq_groups[i]
+                break
+
+        # If there are still other sequence groups left in the schedule, cache
+        # them and flag the engine to reuse the schedule.
+        if len(seq_group_metadata_list) > 0:
+            self._skip_scheduling_next_step = True
+            # Reuse multi-step caching logic
+            self._cache_scheduler_outputs_for_multi_step(
+                virtual_engine=virtual_engine,
+                scheduler_outputs=scheduler_outputs,
+                seq_group_metadata_list=seq_group_metadata_list,
+                allow_async_output_proc=allow_async_output_proc)
+
     def _has_remaining_steps(
         self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
     ) -> bool:
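
Note: the llm_engine.py change above is the core of the fix. When executing the batch raises an InputProcessingError, step() aborts only the offending request, caches the already-scheduled remainder of the batch through the existing multi-step caching path, and sets _skip_scheduling_next_step so the next call replays that cached schedule instead of invoking the scheduler again. The standalone sketch below illustrates the same abort-and-reuse control flow in a simplified form; ToyEngine, InputError, schedule() and execute() are illustrative stand-ins, not vLLM APIs.

from typing import List


class InputError(Exception):
    def __init__(self, request_id: str):
        super().__init__(request_id)
        self.request_id = request_id


class ToyEngine:
    def __init__(self, requests: List[str]):
        self.waiting = list(requests)
        self.cached_batch: List[str] = []
        self.skip_scheduling_next_step = False

    def schedule(self) -> List[str]:
        # Pretend the scheduler picks everything that is waiting.
        batch, self.waiting = self.waiting, []
        return batch

    def execute(self, batch: List[str]) -> List[str]:
        # Pretend the model runner rejects requests whose id starts with "evil".
        for request_id in batch:
            if request_id.startswith("evil"):
                raise InputError(request_id)
        return [f"output-for-{request_id}" for request_id in batch]

    def step(self) -> List[str]:
        if not self.skip_scheduling_next_step:
            self.cached_batch = self.schedule()
        try:
            outputs = self.execute(self.cached_batch)
            self.skip_scheduling_next_step = False
            return outputs
        except InputError as e:
            # Abort only the offending request; keep the rest of the batch
            # cached so the next step retries it without re-scheduling.
            self.cached_batch = [
                r for r in self.cached_batch if r != e.request_id
            ]
            self.skip_scheduling_next_step = bool(self.cached_batch)
            raise


engine = ToyEngine(["req-1", "evil-2", "req-3"])
try:
    engine.step()          # raises InputError for "evil-2"
except InputError as err:
    print("aborted:", err.request_id)
print(engine.step())       # ['output-for-req-1', 'output-for-req-3']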

vllm/engine/multiprocessing/engine.py

Lines changed: 9 additions & 0 deletions
@@ -27,6 +27,7 @@
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 
@@ -210,6 +211,14 @@ def engine_step(self) -> List[RequestOutput]:
             return self.engine.step()
         except SystemExit:
             raise
+        except InputProcessingError as e:
+            # Special case where we handle an error preparing the inputs for
+            # a single request in the batch
+            rpc_err = RPCError(request_id=e.request_id,
+                               is_engine_errored=False,
+                               exception=e.__cause__)
+            self._send_outputs(rpc_err)
+            return []
         except BaseException as e:
             self._set_errored(e)
             rpc_err = RPCError(request_id=None,
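
Note: in the multiprocessing engine, the new except branch turns an input-processing failure into a request-scoped RPCError (is_engine_errored=False, carrying the original cause) instead of marking the whole engine dead. The sketch below shows that routing decision in isolation; the plain queue stands in for vLLM's ZeroMQ transport and the classes are locally defined stand-ins, not vLLM's real types.

import queue
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class RpcError:
    request_id: Optional[str]
    is_engine_errored: bool
    exception: BaseException


class InputProcessingError(Exception):
    # Local stand-in with the same shape as vLLM's exception.
    def __init__(self, request_id: str, message: str):
        super().__init__(message)
        self.request_id = request_id


def route_step_error(exc: BaseException,
                     outboxes: Dict[str, queue.Queue],
                     dead_letter: queue.Queue) -> bool:
    """Return True if the engine should keep serving after this error."""
    if isinstance(exc, InputProcessingError):
        # Only the offending request fails; the engine stays healthy.
        outboxes[exc.request_id].put(
            RpcError(request_id=exc.request_id,
                     is_engine_errored=False,
                     exception=exc.__cause__ or exc))
        return True
    # Anything else is fatal for the whole engine.
    dead_letter.put(
        RpcError(request_id=None, is_engine_errored=True, exception=exc))
    return False


outboxes = {"evil-1": queue.Queue()}
dead_letter = queue.Queue()
ok = route_step_error(InputProcessingError("evil-1", "bad image data"),
                      outboxes, dead_letter)
print(ok, outboxes["evil-1"].get().is_engine_errored)  # True False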

vllm/worker/model_runner.py

Lines changed: 8 additions & 3 deletions
@@ -53,8 +53,8 @@
                         is_pin_memory_available, supports_dynamo,
                         weak_ref_tensor)
 from vllm.worker.model_runner_base import (
-    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
-    _add_attn_metadata_broadcastable_dict,
+    InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
+    ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
@@ -1216,7 +1216,12 @@ def _prepare_model_input_tensors(
         """
         self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            self.builder.add_seq_group(seq_group_metadata)
+            try:
+                self.builder.add_seq_group(seq_group_metadata)
+            except Exception as e:
+                # Raise an exception that tracks the ID of the bad request
+                raise InputProcessingError(seq_group_metadata.request_id,
+                                           str(e)) from e
 
         self.builder.reset_cached_inter_data()
 
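
Note: the model runner change wraps each per-sequence-group build step so that any failure is re-raised as an InputProcessingError recording the offending request id, with the original exception chained via "from e". Below is a minimal, self-contained sketch of that wrap-and-chain pattern; ItemProcessingError and prepare_batch are made-up names used only for illustration.

class ItemProcessingError(Exception):
    def __init__(self, item_id: str, message: str):
        super().__init__(message)
        self.item_id = item_id


def prepare_batch(items: dict) -> list:
    prepared = []
    for item_id, payload in items.items():
        try:
            # Per-item preparation; may fail on malformed input.
            prepared.append(payload["pixels"])
        except Exception as e:
            # Re-raise with the offending id; "from e" keeps the original
            # error reachable via __cause__ for the caller to report.
            raise ItemProcessingError(item_id, str(e)) from e
    return prepared


try:
    prepare_batch({"good": {"pixels": [1, 2, 3]}, "evil": {}})
except ItemProcessingError as err:
    print(err.item_id, "->", repr(err.__cause__))  # evil -> KeyError('pixels')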

vllm/worker/model_runner_base.py

Lines changed: 18 additions & 0 deletions
@@ -261,3 +261,21 @@ def __init__(
 
     def __getattr__(self, attr):
        return getattr(self.model_runner, attr)
+
+
+class InputProcessingError(Exception):
+    """This exception is raised when an error occurs preparing the inputs for
+    a single sequence group.
+    This allows the engine to gracefully handle errors with a single sequence
+    group without having to fail the entire batch.
+    """
+
+    def __init__(self, request_id, message):
+        """request_id is the id of the offending sequence group"""
+        self.request_id = request_id
+        self.message = message
+        super().__init__(self.message)
+
+    def __str__(self):
+        return "Failed to prepare inputs for sequence group with request id: " \
+            f"{self.request_id}, Error: {self.message}"
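
Note: given the class definition above, callers can inspect the offending request id and the formatted message directly. A small usage sketch, assuming a vLLM build that includes this commit; the request id and message are made-up values.

from vllm.worker.model_runner_base import InputProcessingError

try:
    raise InputProcessingError("cmpl-123", "could not decode image")
except InputProcessingError as e:
    print(e.request_id)  # cmpl-123
    print(e)  # Failed to prepare inputs for sequence group with request id:
              # cmpl-123, Error: could not decode image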
