Commit d43f914

[Core][Feature] Input metadata dump on crash (#13407)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
1 parent ed5272c · commit d43f914

File tree

5 files changed: +169 −9 lines


.github/ISSUE_TEMPLATE/400-bug-report.yml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ body:
       ```

       ```
-      The error message you got, with the full traceback.
+      The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
       ```
   validations:
     required: true

tests/basic_correctness/test_basic_correctness.py

Lines changed: 43 additions & 6 deletions
@@ -5,11 +5,13 @@
 """
 import os
 import weakref
+from unittest.mock import Mock

 import pytest

 from vllm import LLM
 from vllm.platforms import current_platform
+from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

 from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal

@@ -152,9 +154,44 @@ def test_models_distributed(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
+
+    from vllm.envs import VLLM_USE_V1
+
+    if not VLLM_USE_V1:
+        pytest.skip("Skipping V0 test, dump input not supported")
+
+    # Needed to mock an error in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
+        if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
+            v1_test_failed_model_execution(vllm_model)
+
+
+def v1_test_failed_model_execution(vllm_model):
+
+    engine = vllm_model.model.llm_engine
+    mocked_execute_model = Mock(
+        side_effect=RuntimeError("Mocked Critical Error"))
+    engine.engine_core.engine_core.model_executor.execute_model = \
+        mocked_execute_model
+
+    with pytest.raises(RuntimeError) as exc_info:
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
+    assert isinstance(exc_info.value, RuntimeError)
+    assert "Mocked Critical Error" in str(exc_info.value)

vllm/logging_utils/dump_input.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@ (new file)
# SPDX-License-Identifier: Apache-2.0

import contextlib
import enum
import json
from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.stats import SchedulerStats
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)


def prepare_object_to_dump(obj) -> str:
    if isinstance(obj, str):
        return f"'{obj}'"  # Wrap strings in single quotes
    elif isinstance(obj, dict):
        dict_str = ', '.join(f'{str(k)}: {prepare_object_to_dump(v)}'
                             for k, v in obj.items())
        return f'{{{dict_str}}}'
    elif isinstance(obj, list):
        return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
    elif isinstance(obj, set):
        return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]"
        # return [prepare_object_to_dump(v) for v in list(obj)]
    elif isinstance(obj, tuple):
        return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
    elif isinstance(obj, enum.Enum):
        return repr(obj)
    elif isinstance(obj, torch.Tensor):
        # Only log tensor metadata (not its contents) so no sensitive data
        # is exposed and some context survives a CUDA runtime crash.
        return (f"Tensor(shape={obj.shape}, "
                f"device={obj.device},"
                f"dtype={obj.dtype})")
    elif hasattr(obj, 'anon_repr'):
        return obj.anon_repr()
    elif hasattr(obj, '__dict__'):
        items = obj.__dict__.items()
        dict_str = ','.join(f'{str(k)}={prepare_object_to_dump(v)}'
                            for k, v in items)
        return f"{type(obj).__name__}({dict_str})"
    else:
        # Fall back to JSON; if that fails, fall back to repr().
        try:
            return json.dumps(obj)
        except (TypeError, OverflowError):
            return repr(obj)


def dump_engine_exception(config: VllmConfig,
                          scheduler_output: SchedulerOutput,
                          scheduler_stats: Optional[SchedulerStats]):
    # NOTE: make sure logging the extra info cannot itself raise
    # unexpected errors during crash handling.
    with contextlib.suppress(BaseException):
        _dump_engine_exception(config, scheduler_output, scheduler_stats)


def _dump_engine_exception(config: VllmConfig,
                           scheduler_output: SchedulerOutput,
                           scheduler_stats: Optional[SchedulerStats]):
    logger.error("Dumping input data")

    logger.error(
        "V1 LLM engine (v%s) with config: %s",
        VLLM_VERSION,
        config,
    )

    try:
        dump_obj = prepare_object_to_dump(scheduler_output)
        logger.error("Dumping scheduler output for model execution:")
        logger.error(dump_obj)
        if scheduler_stats:
            logger.error(scheduler_stats)
    except BaseException as exception:
        logger.error("Error preparing object to dump")
        logger.error(repr(exception))
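prepare_object_to_dump() recursively renders dicts, sequences, tensors, and arbitrary objects into a single log-friendly string, preferring an object's anon_repr() when one exists so prompt contents are not leaked. A rough usage sketch (assuming vllm is installed so the module above is importable; the exact formatting of the output may differ):

import torch

from vllm.logging_utils.dump_input import prepare_object_to_dump

# Plain Python containers are serialized recursively.
print(prepare_object_to_dump({"req_id": "abc", "num_tokens": 17}))
# -> something like: {req_id: 'abc', num_tokens: 17}

# Tensors are reduced to metadata only (shape/device/dtype), never contents.
print(prepare_object_to_dump(torch.zeros(2, 3)))
# -> something like: Tensor(shape=torch.Size([2, 3]), device=cpu,dtype=torch.float32)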

vllm/v1/core/sched/output.py

Lines changed: 27 additions & 0 deletions
@@ -48,6 +48,33 @@ def from_request(
             lora_request=request.lora_request,
         )

+    def __repr__(self):
+        return (f"NewRequestData("
+                f"req_id={self.req_id},"
+                f"prompt_token_ids={self.prompt_token_ids},"
+                f"mm_inputs={self.mm_inputs},"
+                f"mm_hashes={self.mm_hashes},"
+                f"mm_positions={self.mm_positions},"
+                f"sampling_params={self.sampling_params},"
+                f"block_ids={self.block_ids},"
+                f"num_computed_tokens={self.num_computed_tokens},"
+                f"lora_request={self.lora_request}"
+                ")")
+
+    # Version of __repr__ with the prompt data obfuscated
+    def anon_repr(self):
+        return (f"NewRequestData("
+                f"req_id={self.req_id},"
+                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
+                f"mm_inputs={self.mm_inputs},"
+                f"mm_hashes={self.mm_hashes},"
+                f"mm_positions={self.mm_positions},"
+                f"sampling_params={self.sampling_params},"
+                f"block_ids={self.block_ids},"
+                f"num_computed_tokens={self.num_computed_tokens},"
+                f"lora_request={self.lora_request}"
+                ")")
+

 @dataclass
 class CachedRequestData:
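The anon_repr() added here is what prepare_object_to_dump() picks up before falling back to __dict__, so crash dumps report the prompt length instead of the prompt tokens themselves. A toy illustration of that split (ToyRequestData is hypothetical, not vLLM's class):

from dataclasses import dataclass, field


@dataclass
class ToyRequestData:
    req_id: str
    prompt_token_ids: list[int] = field(default_factory=list)

    def __repr__(self) -> str:
        # Full repr: includes the raw prompt token ids.
        return (f"ToyRequestData(req_id={self.req_id},"
                f"prompt_token_ids={self.prompt_token_ids})")

    def anon_repr(self) -> str:
        # Obfuscated repr: only the number of prompt tokens is exposed.
        return (f"ToyRequestData(req_id={self.req_id},"
                f"prompt_token_ids_len={len(self.prompt_token_ids)})")


req = ToyRequestData(req_id="req-0", prompt_token_ids=[101, 202, 303])
print(repr(req))        # ...prompt_token_ids=[101, 202, 303])
print(req.anon_repr())  # ...prompt_token_ids_len=3)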

vllm/v1/engine/core.py

Lines changed: 14 additions & 2 deletions
@@ -19,6 +19,7 @@
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
+from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)

@@ -56,6 +57,7 @@ def __init__(self,
                  executor_fail_callback: Optional[Callable] = None):
         assert vllm_config.model_config.runner_type != "pooling"

+        self.vllm_config = vllm_config
         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)

@@ -191,6 +193,16 @@ def abort_requests(self, request_ids: list[str]):
         self.scheduler.finish_requests(request_ids,
                                        RequestStatus.FINISHED_ABORTED)

+    def execute_model(self, scheduler_output: SchedulerOutput):
+        try:
+            return self.model_executor.execute_model(scheduler_output)
+        except BaseException as err:
+            # NOTE: dump_engine_exception() itself never raises.
+            dump_engine_exception(self.vllm_config, scheduler_output,
+                                  self.scheduler.make_stats())
+            # Re-raise the original exception.
+            raise err
+
     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""

@@ -202,9 +214,9 @@ def step(self) -> EngineCoreOutputs:
                 scheduler_stats=self.scheduler.make_stats(),
             )
         scheduler_output = self.scheduler.schedule()
-        output = self.model_executor.execute_model(scheduler_output)
+        model_output = self.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)  # type: ignore
+            scheduler_output, model_output)  # type: ignore

         return engine_core_outputs