From df0fb3f952eb4b581cfa107fa1a7c97ab6b8bd9f Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 6 Feb 2025 12:57:48 -0700 Subject: [PATCH 1/8] :goal_net: Handle input errors in engine Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/engine.py | 14 +++++ vllm/worker/model_runner.py | 91 ++++++++++++++++++--------- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index a0dd79586588..2074162164c6 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -26,6 +26,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +from vllm.worker.model_runner import InputProcessingError logger = init_logger(__name__) @@ -209,6 +210,16 @@ def engine_step(self) -> List[RequestOutput]: return self.engine.step() except SystemExit: raise + except InputProcessingError as e: + # Special case where we handle an error preparing the inputs for + # a single request in the batch + rpc_err = RPCError(request_id=e.request_id, + is_engine_errored=False, + exception=e.__cause__) + self._send_outputs(rpc_err) + # We must abort the request immediately so that the engine does not + # try to continue to process it in the next step + self.engine.abort_request(e.request_id) except BaseException as e: self._set_errored(e) rpc_err = RPCError(request_id=None, @@ -262,6 +273,7 @@ def _handle_process_request(self, request: RPCProcessRequest): self._send_outputs(rpc_err) try: + print("\n Running add_request \n") self.engine.add_request( request_id=request_id, prompt=request.prompt, @@ -275,6 +287,8 @@ def _handle_process_request(self, request: RPCProcessRequest): logger.info("Added request %s.", request.request_id) except Exception as e: + print("\n Caught add_request failure\n") + # We do not set self._errored = True here, since the error # is due to an issue adding this request to the engine, # rather than an issue with the engine itself. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0bbba55b3b3f..222e1039def2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -75,6 +75,24 @@ torch._dynamo.config.accumulated_cache_size_limit = 128 +class InputProcessingError(Exception): + """This exception is raised when an error occurs preparing the inputs for + a single sequence group. + This allows the engine to gracefully handle errors with a single sequence + group without having to fail the entire batch. 
+ """ + + def __init__(self, request_id, message): + """request_id is the id of the offending sequence group""" + self.request_id = request_id + self.message = message + super().__init__(self.message) + + def __str__(self): + return "Failed to prepare inputs for sequence group with request id: " \ + f"{self.request_id}, Error: {self.message}" + + @dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): """ @@ -683,9 +701,13 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return + print("FOOOOOOOO") + if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data + print("just used mm_data") else: + print("calling mapper") mm_kwargs = self.multi_modal_input_mapper( mm_data, seq_group_metadata.mm_processor_kwargs, @@ -731,36 +753,44 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - encoder_seq_len = 0 - - if self.runner.model_config.is_encoder_decoder: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) + print("Adding seq group!") + try: + seq_ids = seq_group_metadata.seq_data.keys() + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + + inter_data = self.init_cached_inter_data( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums, + reinit=True, + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) + + self.inter_data_list.append(inter_data) + + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + print(f"calling {per_seq_group_fn}") + per_seq_group_fn(inter_data, seq_group_metadata) + except Exception as e: + # Raise an exception that tracks the ID of the bad request + raise InputProcessingError(seq_group_metadata.request_id, + str(e)) from e def _use_captured_graph(self, batch_size: int, @@ -1628,6 +1658,9 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. 
""" + + print("\n\n\n preparing model input \n\n\n\n ") + model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) if get_pp_group().is_last_rank: From 36ac64440f4d8bc354859434e631d215c9c4b620 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 11 Feb 2025 10:05:04 -0700 Subject: [PATCH 2/8] :white_check_mark: Add test for input processing failures Signed-off-by: Joe Runde --- tests/mq_llm_engine/test_error_handling.py | 78 ++++++++++++++++++++++ vllm/engine/llm_engine.py | 61 ++++++++++++++++- vllm/engine/multiprocessing/engine.py | 1 - vllm/worker/model_runner.py | 9 --- 4 files changed, 136 insertions(+), 13 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 35d001781110..aad7fc5303c1 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -18,6 +18,7 @@ from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.lora.request import LoRARequest +from vllm.sequence import SequenceGroupMetadata from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser @@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket): await client.check_health() client.close() + + +def run_with_evil_input_processing(engine_args: AsyncEngineArgs, + ipc_path: str): + """Simulate an exception while preparing inputs for the model. + In the wild, this could be something like a multimodal input processor + failing on invalid image data.""" + + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + runner = engine.engine.model_executor.driver_worker.worker.model_runner + + # Raise error in the model runner when adding a sequence group. + # See class ModelInputForGPUBuilder + def raiser(_, seq_group_metadata: SequenceGroupMetadata): + if seq_group_metadata.request_id.startswith("evil"): + raise RAISED_ERROR(RAISED_VALUE) + + runner.builder.per_seq_group_compute_fns.append(raiser) + + # Run engine. 
+ engine.start() + + +@pytest.mark.asyncio +async def test_failed_inputs(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_input_processing) as engine: + + client = await engine.make_client() + assert client.is_running + + # Engine should be healthy + await client.check_health() + + async def run_failing_request(): + async for _ in client.generate( + prompt="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), + request_id="evil" + str(uuid.uuid4())): + pass + + async def run_passing_request(): + async for _ in client.generate( + prompt="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), + request_id=str(uuid.uuid4())): + pass + + passing_tasks = [ + asyncio.create_task(run_passing_request()) for _ in range(10) + ] + failing_tasks = [ + asyncio.create_task(run_failing_request()) for _ in range(10) + ] + await asyncio.gather(*failing_tasks, return_exceptions=True) + await asyncio.gather(*passing_tasks) + + # All the bad inputs should have raised + for task in failing_tasks: + with pytest.raises(RAISED_ERROR): + task.result() + + # But all good inputs should have still succeeded + for task in passing_tasks: + task.result() + + # And the engine should remain healthy + assert not client.errored + await client.check_health() + + client.close() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df32..9d6ed2abf8d1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -405,6 +405,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + # Flag to set when an input fails to process and the engine should run + # the next step without re-scheduling. + self._skip_sheduling_next_step = False + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1329,7 +1333,9 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. - if not self._has_remaining_steps(seq_group_metadata_list): + if not self._has_remaining_steps( + seq_group_metadata_list + ) and not self._skip_sheduling_next_step: # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc @@ -1383,8 +1389,25 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) + # todo: import loop + from vllm.worker.model_runner import InputProcessingError + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + self._skip_sheduling_next_step = False + except InputProcessingError as e: + # The input for this request cannot be processed, so we must + # abort it. If there are remaining requests in the batch that + # have been scheduled, they will be retried on the next step. 
+ invalid_request_id = e.request_id + self._abort_and_cache_schedule( + request_id=invalid_request_id, + virtual_engine=virtual_engine, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + allow_async_output_proc=allow_async_output_proc) + # Raise so the caller is notified that this request failed + raise # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. @@ -1459,6 +1482,38 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: return ctx.request_outputs + def _abort_and_cache_schedule( + self, request_id: str, virtual_engine: int, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + """Aborts a single request, and caches the scheduler outputs minus that + request. This allows the next step to continue processing the remaining + requests without having to re-run the scheduler.""" + + # Abort the request and remove its sequence group from the current + # schedule + self.abort_request(request_id) + for i, metadata in enumerate(seq_group_metadata_list): + if metadata.request_id == request_id: + del seq_group_metadata_list[i] + break + for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): + if group.seq_group.request_id == request_id: + del scheduler_outputs.scheduled_seq_groups[i] + break + + # If there are still other sequence groups left in the schedule, cache + # them and flag the engine to reuse the schedule. + if len(seq_group_metadata_list) > 0: + self._skip_sheduling_next_step = True + # Reuse multi-step caching logic + self._cache_scheduler_outputs_for_multi_step( + virtual_engine=virtual_engine, + scheduler_outputs=scheduler_outputs, + seq_group_metadata_list=seq_group_metadata_list, + allow_async_output_proc=allow_async_output_proc) + def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] ) -> bool: diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 2074162164c6..97867f4fb5fd 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -273,7 +273,6 @@ def _handle_process_request(self, request: RPCProcessRequest): self._send_outputs(rpc_err) try: - print("\n Running add_request \n") self.engine.add_request( request_id=request_id, prompt=request.prompt, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 222e1039def2..7bea5e99c0d1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -701,13 +701,9 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - print("FOOOOOOOO") - if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data - print("just used mm_data") else: - print("calling mapper") mm_kwargs = self.multi_modal_input_mapper( mm_data, seq_group_metadata.mm_processor_kwargs, @@ -754,7 +750,6 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" - print("Adding seq group!") try: seq_ids = seq_group_metadata.seq_data.keys() n_seqs = len(seq_ids) @@ -785,7 +780,6 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_fn in self.per_seq_compute_fns: per_seq_fn(inter_data, seq_idx, seq_group_metadata) for per_seq_group_fn in self.per_seq_group_compute_fns: - print(f"calling 
{per_seq_group_fn}") per_seq_group_fn(inter_data, seq_group_metadata) except Exception as e: # Raise an exception that tracks the ID of the bad request @@ -1658,9 +1652,6 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ - - print("\n\n\n preparing model input \n\n\n\n ") - model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) if get_pp_group().is_last_rank: From 54cd2575b2b047bec65ce101f8db7ccc5304c00e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 11 Feb 2025 10:07:13 -0700 Subject: [PATCH 3/8] :art: cleanup Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 97867f4fb5fd..93765b94d629 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -286,8 +286,6 @@ def _handle_process_request(self, request: RPCProcessRequest): logger.info("Added request %s.", request.request_id) except Exception as e: - print("\n Caught add_request failure\n") - # We do not set self._errored = True here, since the error # is due to an issue adding this request to the engine, # rather than an issue with the engine itself. From aeaf2eba225007af026293748f6911847f3e33f4 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 11 Feb 2025 10:10:13 -0700 Subject: [PATCH 4/8] :fire: cleanup abort Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/engine.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 93765b94d629..f2a8b1f1a6a8 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -217,9 +217,6 @@ def engine_step(self) -> List[RequestOutput]: is_engine_errored=False, exception=e.__cause__) self._send_outputs(rpc_err) - # We must abort the request immediately so that the engine does not - # try to continue to process it in the next step - self.engine.abort_request(e.request_id) except BaseException as e: self._set_errored(e) rpc_err = RPCError(request_id=None, From 234a8269338c77122c2f0ec321d7f06d8c0c8d0b Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 11 Feb 2025 11:51:37 -0700 Subject: [PATCH 5/8] :art: explicitly return empty outputs on failure Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index f2a8b1f1a6a8..d5a7768ccd8e 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -217,6 +217,7 @@ def engine_step(self) -> List[RequestOutput]: is_engine_errored=False, exception=e.__cause__) self._send_outputs(rpc_err) + return [] except BaseException as e: self._set_errored(e) rpc_err = RPCError(request_id=None, From 49ad2b1b7f37e16854483fa5bf0295c7f1e34802 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 11 Feb 2025 12:26:09 -0700 Subject: [PATCH 6/8] :recycle: move input processing error def Signed-off-by: Joe Runde --- vllm/engine/llm_engine.py | 3 +-- vllm/engine/multiprocessing/engine.py | 2 +- vllm/worker/model_runner.py | 22 ++-------------------- vllm/worker/model_runner_base.py | 18 ++++++++++++++++++ 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9d6ed2abf8d1..e57cd354e5bf 100644 --- a/vllm/engine/llm_engine.py +++ 
b/vllm/engine/llm_engine.py @@ -60,6 +60,7 @@ usage_message) from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.model_runner_base import InputProcessingError logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -1389,8 +1390,6 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - # todo: import loop - from vllm.worker.model_runner import InputProcessingError try: outputs = self.model_executor.execute_model( execute_model_req=execute_model_req) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index d5a7768ccd8e..1cd6ca5b1c49 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -26,7 +26,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext -from vllm.worker.model_runner import InputProcessingError +from vllm.worker.model_runner_base import InputProcessingError logger = init_logger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7bea5e99c0d1..e588576a92e1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -53,8 +53,8 @@ is_pin_memory_available, supports_dynamo, weak_ref_tensor) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, + InputProcessingError, ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) @@ -75,24 +75,6 @@ torch._dynamo.config.accumulated_cache_size_limit = 128 -class InputProcessingError(Exception): - """This exception is raised when an error occurs preparing the inputs for - a single sequence group. - This allows the engine to gracefully handle errors with a single sequence - group without having to fail the entire batch. - """ - - def __init__(self, request_id, message): - """request_id is the id of the offending sequence group""" - self.request_id = request_id - self.message = message - super().__init__(self.message) - - def __str__(self): - return "Failed to prepare inputs for sequence group with request id: " \ - f"{self.request_id}, Error: {self.message}" - - @dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): """ diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 38d2b712eff5..670cc032a0df 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -258,3 +258,21 @@ def __init__( def __getattr__(self, attr): return getattr(self.model_runner, attr) + + +class InputProcessingError(Exception): + """This exception is raised when an error occurs preparing the inputs for + a single sequence group. + This allows the engine to gracefully handle errors with a single sequence + group without having to fail the entire batch. 
+ """ + + def __init__(self, request_id, message): + """request_id is the id of the offending sequence group""" + self.request_id = request_id + self.message = message + super().__init__(self.message) + + def __str__(self): + return "Failed to prepare inputs for sequence group with request id: " \ + f"{self.request_id}, Error: {self.message}" From e90af369158f9b16e946289fc141b4f1e4965b21 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 12 Feb 2025 10:27:19 -0700 Subject: [PATCH 7/8] :recycle: fix name and add comment Signed-off-by: Joe Runde --- vllm/engine/llm_engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e57cd354e5bf..aaa194e2a7d1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -408,7 +408,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Flag to set when an input fails to process and the engine should run # the next step without re-scheduling. - self._skip_sheduling_next_step = False + self._skip_scheduling_next_step = False def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1334,9 +1334,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. if not self._has_remaining_steps( seq_group_metadata_list - ) and not self._skip_sheduling_next_step: + ) and not self._skip_scheduling_next_step: # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc @@ -1393,7 +1395,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: try: outputs = self.model_executor.execute_model( execute_model_req=execute_model_req) - self._skip_sheduling_next_step = False + self._skip_scheduling_next_step = False except InputProcessingError as e: # The input for this request cannot be processed, so we must # abort it. If there are remaining requests in the batch that @@ -1505,7 +1507,7 @@ def _abort_and_cache_schedule( # If there are still other sequence groups left in the schedule, cache # them and flag the engine to reuse the schedule. 
if len(seq_group_metadata_list) > 0: - self._skip_sheduling_next_step = True + self._skip_scheduling_next_step = True # Reuse multi-step caching logic self._cache_scheduler_outputs_for_multi_step( virtual_engine=virtual_engine, From 5ba9c4c7ce0ae738591be7d3f1b7ce64e82e4d30 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 25 Feb 2025 11:49:25 -0700 Subject: [PATCH 8/8] :recycle: move try/catch to be more concise Signed-off-by: Joe Runde --- vllm/worker/model_runner.py | 71 ++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e588576a92e1..4dd5d8e90c10 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -731,42 +731,36 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" + seq_ids = seq_group_metadata.seq_data.keys() + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt - try: - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - - encoder_seq_len = 0 - - if self.runner.model_config.is_encoder_decoder: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) - except Exception as e: - # Raise an exception that tracks the ID of the bad request - raise InputProcessingError(seq_group_metadata.request_id, - str(e)) from e + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + + inter_data = self.init_cached_inter_data( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums, + reinit=True, + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) + + self.inter_data_list.append(inter_data) + + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + per_seq_group_fn(inter_data, seq_group_metadata) def _use_captured_graph(self, batch_size: int, @@ -1222,7 +1216,12 @@ def _prepare_model_input_tensors( """ self.builder.prepare(finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - self.builder.add_seq_group(seq_group_metadata) + try: + self.builder.add_seq_group(seq_group_metadata) + except Exception as e: + # Raise an exception that tracks the ID of the bad request + raise InputProcessingError(seq_group_metadata.request_id, + str(e)) from e self.builder.reset_cached_inter_data()
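
As a rough, end-to-end illustration of the behavior this series introduces — a failure while preparing the inputs for one request is tagged with that request's id, only that request is aborted, the rest of the scheduled batch is reused on the next step, and the engine itself is never marked errored — here is a small standalone sketch. The names ToyEngine and _prepare_inputs are hypothetical stand-ins for LLMEngine.step and ModelInputForGPUBuilder.add_seq_group; this is not the real vLLM API, just the control flow in miniature.

    class InputProcessingError(Exception):
        """Raised when preparing inputs for a single request fails."""

        def __init__(self, request_id: str, message: str):
            self.request_id = request_id
            self.message = message
            super().__init__(message)


    class ToyEngine:
        """Hypothetical stand-in for the engine, showing only the control flow."""

        def __init__(self, batch):
            self.batch = list(batch)               # scheduled request ids
            self.skip_scheduling_next_step = False

        def _prepare_inputs(self, request_id: str):
            # Stand-in for add_seq_group(): any failure is re-raised tagged
            # with the offending request id, mirroring the wrapper added in
            # _prepare_model_input_tensors.
            try:
                if request_id.startswith("evil"):
                    raise ValueError("bad multimodal input")
            except Exception as e:
                raise InputProcessingError(request_id, str(e)) from e

        def step(self):
            try:
                for request_id in self.batch:
                    self._prepare_inputs(request_id)
                self.skip_scheduling_next_step = False
                return [f"output-{r}" for r in self.batch]
            except InputProcessingError as e:
                # Drop only the bad request and keep the remaining schedule for
                # the next step, mirroring _abort_and_cache_schedule().
                self.batch = [r for r in self.batch if r != e.request_id]
                self.skip_scheduling_next_step = bool(self.batch)
                raise


    engine = ToyEngine(["good-1", "evil-2", "good-3"])
    try:
        engine.step()
    except InputProcessingError as e:
        # The caller (MQLLMEngine.engine_step in the real code) reports the
        # error for just this request; the engine is not marked errored.
        print(f"aborted {e.request_id}: {e.message}")
    print(engine.step())   # remaining requests still produce outputs

The key design choice, reflected in engine_step sending RPCError with is_engine_errored=False, is that an input-processing failure is attributed to a single request rather than to the engine: the offending request gets an error response and is aborted, while previously scheduled requests are retried on the next step without re-running the scheduler.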