Skip to content

Commit

Permalink
[MISC] Dump model runner inputs when crashing (vllm-project#8305)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and MengqingCao committed Sep 30, 2024
1 parent 4cbccb1 commit db0f4f1
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
9 changes: 9 additions & 0 deletions .github/ISSUE_TEMPLATE/400-bug report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ body:
</details>
validations:
required: true
- type: textarea
attributes:
label: Model Input Dumps
description: |
If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
placeholder: |
Upload the dumped input file.
validations:
required: false
- type: textarea
attributes:
label: 🐛 Describe the bug
Expand Down
30 changes: 30 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import pickle
import re
import weakref
from unittest.mock import patch

import pytest

from vllm import LLM
from vllm.utils import is_hip
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from ..models.utils import check_outputs_equal

Expand Down Expand Up @@ -67,3 +71,29 @@ def test_models(
name_0="hf",
name_1="vllm",
)


def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
side_effect=ValueError()):
with pytest.raises(ValueError) as exc_info:
vllm_runner("facebook/opt-125m",
dtype="half",
enforce_eager=False,
gpu_memory_utilization=0.7)
matches = re.search(r"input dumped to (.+).pkl",
str(exc_info.value))
assert matches is not None
filename = f"{matches.group(1)}.pkl"

with open(filename, "rb") as filep:
inputs = pickle.load(filep)

if any(key not in inputs for key in ("arg_1", "arg_2", "arg_3")):
raise AssertionError("Missing keys in dumped inputs. Dumped keys: "
f"{list(inputs.keys())}")
assert isinstance(inputs["arg_1"],
ModelInputForGPUWithSamplingMetadata)
finally:
os.remove(filename)
3 changes: 2 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
Expand Down Expand Up @@ -1489,6 +1489,7 @@ def prepare_model_input(
virtual_engine=virtual_engine)

@torch.inference_mode()
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
Expand Down
34 changes: 34 additions & 0 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import dataclasses
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)

Expand Down Expand Up @@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
return tensor_dict


def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
exclude_kwargs: Optional[List[str]] = None):

def _inner(func):

@wraps(func)
def _wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
with open(filename, "wb") as filep:
dumped_inputs = {
k: v
for k, v in kwargs.items()
if k not in (exclude_kwargs or [])
}
for i, arg in enumerate(args):
if i not in (exclude_args or []):
dumped_inputs[f"arg_{i}"] = arg
pickle.dump(dumped_inputs, filep)
raise type(err)(
f"Error in model execution (input dumped to {filename}): "
f"{str(err)}") from err

return _wrapper

return _inner


class BroadcastableModelInput(ABC):

@abstractmethod
Expand Down

0 comments on commit db0f4f1

Please sign in to comment.