78 changes: 78 additions & 0 deletions tests/mq_llm_engine/test_error_handling.py
@@ -18,6 +18,7 @@
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
await client.check_health()

client.close()


def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
ipc_path: str):
"""Simulate an exception while preparing inputs for the model.
In the wild, this could be something like a multimodal input processor
failing on invalid image data."""

# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)

runner = engine.engine.model_executor.driver_worker.worker.model_runner

# Raise error in the model runner when adding a sequence group.
# See class ModelInputForGPUBuilder
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
if seq_group_metadata.request_id.startswith("evil"):
raise RAISED_ERROR(RAISED_VALUE)

runner.builder.per_seq_group_compute_fns.append(raiser)

# Run engine.
engine.start()


@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_input_processing) as engine:

client = await engine.make_client()
assert client.is_running

# Engine should be healthy
await client.check_health()

async def run_failing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id="evil" + str(uuid.uuid4())):
pass

async def run_passing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=str(uuid.uuid4())):
pass

passing_tasks = [
asyncio.create_task(run_passing_request()) for _ in range(10)
]
failing_tasks = [
asyncio.create_task(run_failing_request()) for _ in range(10)
]
await asyncio.gather(*failing_tasks, return_exceptions=True)
await asyncio.gather(*passing_tasks)

# All the bad inputs should have raised
for task in failing_tasks:
with pytest.raises(RAISED_ERROR):
task.result()

# But all good inputs should have still succeeded
for task in passing_tasks:
task.result()

# And the engine should remain healthy
assert not client.errored
await client.check_health()

client.close()
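
The helper's docstring points at a multimodal input processor choking on invalid image data as the real-world analogue of the injected raiser. As a rough standalone sketch of that kind of per-request failure (the function name and the use of Pillow here are illustrative assumptions, not part of this PR):

from io import BytesIO

from PIL import Image  # assumption: Pillow is available; not part of this PR


def decode_image_or_fail(image_bytes: bytes) -> Image.Image:
    """Stand-in for the kind of input processor the docstring alludes to:
    it raises on undecodable image data, the same shape of per-request
    failure the test injects via the raiser hook."""
    try:
        img = Image.open(BytesIO(image_bytes))
        img.load()  # force decoding so corrupt data fails here, not later
        return img
    except Exception as exc:
        raise ValueError(f"invalid image data: {exc}") from exc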
62 changes: 59 additions & 3 deletions vllm/engine/llm_engine.py
@@ -60,6 +60,7 @@
from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -410,6 +411,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:

self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self._skip_scheduling_next_step = False

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).

@@ -1334,7 +1339,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# The scheduler is also skipped if a single request caused the last
# engine step to fail, and the previous schedule needs to be rerun.
if not self._has_remaining_steps(
seq_group_metadata_list
) and not self._skip_scheduling_next_step:
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
@@ -1388,8 +1397,23 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
try:
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
self._skip_scheduling_next_step = False
except InputProcessingError as e:
# The input for this request cannot be processed, so we must
# abort it. If there are remaining requests in the batch that
# have been scheduled, they will be retried on the next step.
invalid_request_id = e.request_id
self._abort_and_cache_schedule(
request_id=invalid_request_id,
virtual_engine=virtual_engine,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
allow_async_output_proc=allow_async_output_proc)
# Raise so the caller is notified that this request failed
raise

# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
@@ -1464,6 +1488,38 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:

return ctx.request_outputs

def _abort_and_cache_schedule(
self, request_id: str, virtual_engine: int,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs,
allow_async_output_proc: bool) -> None:
"""Aborts a single request, and caches the scheduler outputs minus that
request. This allows the next step to continue processing the remaining
requests without having to re-run the scheduler."""

# Abort the request and remove its sequence group from the current
# schedule
self.abort_request(request_id)
for i, metadata in enumerate(seq_group_metadata_list):
if metadata.request_id == request_id:
del seq_group_metadata_list[i]
break
for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
if group.seq_group.request_id == request_id:
del scheduler_outputs.scheduled_seq_groups[i]
break

# If there are still other sequence groups left in the schedule, cache
# them and flag the engine to reuse the schedule.
if len(seq_group_metadata_list) > 0:
self._skip_scheduling_next_step = True
# Reuse multi-step caching logic
self._cache_scheduler_outputs_for_multi_step(
virtual_engine=virtual_engine,
scheduler_outputs=scheduler_outputs,
seq_group_metadata_list=seq_group_metadata_list,
allow_async_output_proc=allow_async_output_proc)

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
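
Taken together, the llm_engine.py changes above add a "drop one request, keep the rest of the batch" recovery path: step() catches InputProcessingError, _abort_and_cache_schedule() removes the offending sequence group and caches what remains, and _skip_scheduling_next_step makes the next step() reuse that cached schedule instead of calling the scheduler. The self-contained sketch below is illustrative only (SketchEngine and its fields are invented for this example, and the real step() does far more); it just shows how the flag and the cached schedule interact:

from dataclasses import dataclass, field
from typing import List, Optional


class InputProcessingError(Exception):
    """Minimal stand-in for the class added in model_runner_base.py."""

    def __init__(self, request_id: str):
        super().__init__(request_id)
        self.request_id = request_id


@dataclass
class SketchEngine:
    pending: List[str]                        # request ids waiting to be scheduled
    cached_schedule: Optional[List[str]] = None
    skip_scheduling_next_step: bool = False
    bad_ids: set = field(default_factory=set)

    def schedule(self) -> List[str]:
        batch, self.pending = self.pending, []
        return batch

    def execute_model(self, batch: List[str]) -> List[str]:
        for rid in batch:
            if rid in self.bad_ids:
                raise InputProcessingError(rid)
        return [f"output-for-{rid}" for rid in batch]

    def step(self) -> List[str]:
        # Reuse the previous schedule if the last step failed on one request.
        if self.skip_scheduling_next_step and self.cached_schedule is not None:
            batch = self.cached_schedule
        else:
            batch = self.schedule()
        try:
            outputs = self.execute_model(batch)
            self.skip_scheduling_next_step = False
            return outputs
        except InputProcessingError as e:
            # Abort only the offending request, keep the rest of the batch
            # for the next step, then re-raise so the caller sees the error.
            batch = [rid for rid in batch if rid != e.request_id]
            if batch:
                self.cached_schedule = batch
                self.skip_scheduling_next_step = True
            raise


engine = SketchEngine(pending=["good-1", "evil-2", "good-3"], bad_ids={"evil-2"})
try:
    engine.step()      # fails on "evil-2"
except InputProcessingError:
    pass
print(engine.step())   # ['output-for-good-1', 'output-for-good-3']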
9 changes: 9 additions & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -27,6 +27,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError

logger = init_logger(__name__)

@@ -210,6 +211,14 @@ def engine_step(self) -> List[RequestOutput]:
return self.engine.step()
except SystemExit:
raise
except InputProcessingError as e:
# Special case where we handle an error preparing the inputs for
# a single request in the batch
rpc_err = RPCError(request_id=e.request_id,
is_engine_errored=False,
exception=e.__cause__)
self._send_outputs(rpc_err)
return []
except BaseException as e:
self._set_errored(e)
rpc_err = RPCError(request_id=None,
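
The RPCError built above carries is_engine_errored=False and the original cause, so only the failing request is reported as errored while the engine keeps serving. The MQLLMEngineClient side is not part of this diff, but here is a tiny illustrative sketch of why a per-request error payload reaches only the caller that submitted it (the queues and the demo function are made up):

import asyncio
from typing import Dict, Union


async def demo() -> None:
    # One output queue per request id; an error payload is routed only to
    # the generator waiting on that id.  Illustrative only.
    queues: Dict[str, asyncio.Queue] = {
        "good-1": asyncio.Queue(),
        "evil-2": asyncio.Queue(),
    }

    def deliver(request_id: str, payload: Union[str, Exception]) -> None:
        queues[request_id].put_nowait(payload)

    # The engine loop sends a normal output for one request and an
    # RPCError-like payload for the other; the engine itself keeps running.
    deliver("good-1", "token")
    deliver("evil-2", KeyError("failed to prepare inputs"))

    assert await queues["good-1"].get() == "token"
    err = await queues["evil-2"].get()
    assert isinstance(err, KeyError)  # only this request observes the failure


asyncio.run(demo())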
11 changes: 8 additions & 3 deletions vllm/worker/model_runner.py
@@ -53,8 +53,8 @@
is_pin_memory_available, supports_dynamo,
weak_ref_tensor)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
@@ -1216,7 +1216,12 @@ def _prepare_model_input_tensors(
"""
self.builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
self.builder.add_seq_group(seq_group_metadata)
try:
self.builder.add_seq_group(seq_group_metadata)
except Exception as e:
# Raise an exception that tracks the ID of the bad request
raise InputProcessingError(seq_group_metadata.request_id,
str(e)) from e

self.builder.reset_cached_inter_data()

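
The raise InputProcessingError(...) from e above is what makes e.__cause__ available when the multiprocessing engine builds its RPCError (see the engine.py hunk earlier). A minimal standalone illustration of that wrap-and-chain pattern, with a made-up request id and error:

class InputProcessingError(Exception):
    """Same shape as the class added in model_runner_base.py below."""

    def __init__(self, request_id: str, message: str):
        self.request_id = request_id
        self.message = message
        super().__init__(message)


def add_seq_group_that_fails() -> None:
    raise ValueError("could not decode image")  # the underlying failure


try:
    try:
        add_seq_group_that_fails()
    except Exception as e:
        # Same wrap-and-chain pattern as _prepare_model_input_tensors above.
        raise InputProcessingError("evil-123", str(e)) from e
except InputProcessingError as err:
    assert isinstance(err.__cause__, ValueError)  # what the RPCError forwards
    print(err.request_id, err.__cause__)          # evil-123 could not decode image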
18 changes: 18 additions & 0 deletions vllm/worker/model_runner_base.py
@@ -261,3 +261,21 @@ def __init__(

def __getattr__(self, attr):
return getattr(self.model_runner, attr)


class InputProcessingError(Exception):
"""This exception is raised when an error occurs preparing the inputs for
a single sequence group.
This allows the engine to gracefully handle errors with a single sequence
group without having to fail the entire batch.
"""

def __init__(self, request_id, message):
"""request_id is the id of the offending sequence group"""
self.request_id = request_id
self.message = message
super().__init__(self.message)

def __str__(self):
return "Failed to prepare inputs for sequence group with request id: " \
f"{self.request_id}, Error: {self.message}"