60 | 60 | from vllm.utils import (Counter, Device, deprecate_kwargs, |
61 | 61 | resolve_obj_by_qualname, weak_bind) |
62 | 62 | from vllm.version import __version__ as VLLM_VERSION |
| 63 | +from vllm.worker.model_runner_base import InputProcessingError |
63 | 64 |
64 | 65 | logger = init_logger(__name__) |
65 | 66 | _LOCAL_LOGGING_INTERVAL_SEC = 5 |
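The error-handling path added to `step()` below reads `e.request_id` off this exception, so the import only makes sense if the exception carries the ID of the offending sequence group. A minimal sketch of the shape `InputProcessingError` is assumed to have (the real definition lives in `vllm/worker/model_runner_base.py` and may differ in detail):

```python
class InputProcessingError(Exception):
    """Raised when preparing the inputs for a single sequence group fails.

    Carrying the request ID lets the engine abort just the offending
    request instead of failing the whole scheduled batch.
    """

    def __init__(self, request_id: str, message: str) -> None:
        self.request_id = request_id
        super().__init__(message)
```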
@@ -410,6 +411,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: |
410 | 411 |
411 | 412 | self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} |
412 | 413 |
 | 414 | + # Flag set when an input fails to process and the engine should run 
 | 415 | + # the next step without re-scheduling. 
| 416 | + self._skip_scheduling_next_step = False |
| 417 | + |
413 | 418 | def _initialize_kv_caches(self) -> None: |
414 | 419 | """Initialize the KV cache in the worker(s). |
415 | 420 |
@@ -1334,7 +1339,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: |
1334 | 1339 | # Skip the scheduler if there are any remaining steps in the seq groups. |
1335 | 1340 | # This ensures that the scheduler is only called again when the current |
1336 | 1341 | # batch has completed. |
1337 | | - if not self._has_remaining_steps(seq_group_metadata_list): |
| 1342 | + # The scheduler is also skipped if a single request caused the last |
| 1343 | + # engine step to fail, and the previous schedule needs to be rerun. |
| 1344 | + if not self._has_remaining_steps( |
| 1345 | + seq_group_metadata_list |
| 1346 | + ) and not self._skip_scheduling_next_step: |
1338 | 1347 | # Schedule iteration |
1339 | 1348 | (seq_group_metadata_list, scheduler_outputs, |
1340 | 1349 | allow_async_output_proc |
@@ -1388,8 +1397,23 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: |
1388 | 1397 | execute_model_req.async_callback = self.async_callbacks[ |
1389 | 1398 | virtual_engine] |
1390 | 1399 |
1391 | | - outputs = self.model_executor.execute_model( |
1392 | | - execute_model_req=execute_model_req) |
| 1400 | + try: |
| 1401 | + outputs = self.model_executor.execute_model( |
| 1402 | + execute_model_req=execute_model_req) |
| 1403 | + self._skip_scheduling_next_step = False |
| 1404 | + except InputProcessingError as e: |
| 1405 | + # The input for this request cannot be processed, so we must |
| 1406 | + # abort it. If there are remaining requests in the batch that |
| 1407 | + # have been scheduled, they will be retried on the next step. |
| 1408 | + invalid_request_id = e.request_id |
| 1409 | + self._abort_and_cache_schedule( |
| 1410 | + request_id=invalid_request_id, |
| 1411 | + virtual_engine=virtual_engine, |
| 1412 | + seq_group_metadata_list=seq_group_metadata_list, |
| 1413 | + scheduler_outputs=scheduler_outputs, |
| 1414 | + allow_async_output_proc=allow_async_output_proc) |
| 1415 | + # Raise so the caller is notified that this request failed |
| 1416 | + raise |
1393 | 1417 |
1394 | 1418 | # We need to do this here so that last step's sampled_token_ids can |
1395 | 1419 | # be passed to the next iteration for PP. |
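Because the exception is re-raised after the abort, code driving the engine is expected to catch it, surface the failure for that one request, and keep stepping; the remaining scheduled requests are replayed from the cached schedule on the next call. A rough sketch of such a driver loop, assuming an initialized `LLMEngine` (the logging here is a placeholder, not a vLLM API):

```python
from vllm.worker.model_runner_base import InputProcessingError


def run_to_completion(engine) -> None:
    """Step an engine until all requests finish, tolerating bad inputs."""
    while engine.has_unfinished_requests():
        try:
            outputs = engine.step()
        except InputProcessingError as e:
            # Only the offending request was aborted inside step(); the rest
            # of the batch stays cached and is retried on the next step.
            print(f"request {e.request_id} failed input processing: {e}")
            continue
        for output in outputs:
            if output.finished:
                print(f"request {output.request_id} finished")
```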
@@ -1464,6 +1488,38 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: |
1464 | 1488 |
1465 | 1489 | return ctx.request_outputs |
1466 | 1490 |
| 1491 | + def _abort_and_cache_schedule( |
| 1492 | + self, request_id: str, virtual_engine: int, |
| 1493 | + seq_group_metadata_list: List[SequenceGroupMetadata], |
| 1494 | + scheduler_outputs: SchedulerOutputs, |
| 1495 | + allow_async_output_proc: bool) -> None: |
| 1496 | + """Aborts a single request, and caches the scheduler outputs minus that |
| 1497 | + request. This allows the next step to continue processing the remaining |
| 1498 | + requests without having to re-run the scheduler.""" |
| 1499 | + |
| 1500 | + # Abort the request and remove its sequence group from the current |
| 1501 | + # schedule |
| 1502 | + self.abort_request(request_id) |
| 1503 | + for i, metadata in enumerate(seq_group_metadata_list): |
| 1504 | + if metadata.request_id == request_id: |
| 1505 | + del seq_group_metadata_list[i] |
| 1506 | + break |
| 1507 | + for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): |
| 1508 | + if group.seq_group.request_id == request_id: |
| 1509 | + del scheduler_outputs.scheduled_seq_groups[i] |
| 1510 | + break |
| 1511 | + |
| 1512 | + # If there are still other sequence groups left in the schedule, cache |
| 1513 | + # them and flag the engine to reuse the schedule. |
| 1514 | + if len(seq_group_metadata_list) > 0: |
| 1515 | + self._skip_scheduling_next_step = True |
| 1516 | + # Reuse multi-step caching logic |
| 1517 | + self._cache_scheduler_outputs_for_multi_step( |
| 1518 | + virtual_engine=virtual_engine, |
| 1519 | + scheduler_outputs=scheduler_outputs, |
| 1520 | + seq_group_metadata_list=seq_group_metadata_list, |
| 1521 | + allow_async_output_proc=allow_async_output_proc) |
| 1522 | + |
1467 | 1523 | def _has_remaining_steps( |
1468 | 1524 | self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] |
1469 | 1525 | ) -> bool: |
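Putting the pieces together, the flag and the cached schedule give a two-step recovery: the failing step aborts one request and caches the rest, and the following step replays that cache before normal scheduling resumes. The toy below models only that lifecycle; it is an illustration under simplified assumptions, not vLLM code:

```python
from dataclasses import dataclass, field
from typing import List, Set


class InputError(Exception):
    """Stands in for InputProcessingError; carries the offending request id."""

    def __init__(self, request_id: str) -> None:
        super().__init__(request_id)
        self.request_id = request_id


@dataclass
class ToyEngine:
    pending: List[str]                 # request ids still waiting to finish
    bad: Set[str]                      # ids whose inputs cannot be processed
    _cached_batch: List[str] = field(default_factory=list)
    _skip_scheduling_next_step: bool = False

    def step(self) -> List[str]:
        # Only "schedule" (copy pending) when no cached batch is waiting.
        if not self._skip_scheduling_next_step:
            self._cached_batch = list(self.pending)
        batch = self._cached_batch
        for rid in list(batch):
            if rid in self.bad:
                # Abort just the offending request; keep the rest cached so
                # the next step can replay them without re-scheduling.
                batch.remove(rid)
                self.pending.remove(rid)
                self._skip_scheduling_next_step = bool(batch)
                raise InputError(rid)
        self._skip_scheduling_next_step = False
        for rid in batch:
            self.pending.remove(rid)
        return batch


engine = ToyEngine(pending=["a", "b", "c"], bad={"b"})
try:
    engine.step()                      # raises for "b"; "a" and "c" stay cached
except InputError as e:
    print("aborted:", e.request_id)    # -> aborted: b
print(engine.step())                   # -> ['a', 'c'] from the replayed schedule
```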