Skip to content

Commit 7c65527

Browse files
authored
[V1] Use pickle for serializing EngineCoreRequest & Add multimodal inputs to EngineCoreRequest (#10245)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 47db6ec commit 7c65527

File tree

5 files changed

+25
-5
lines changed

5 files changed

+25
-5
lines changed

vllm/v1/engine/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import enum
22
from dataclasses import dataclass
3-
from typing import List, Optional, Union
3+
from typing import Any, Dict, List, Optional, Union
44

55
import msgspec
66

77
from vllm.lora.request import LoRARequest
8+
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
89
from vllm.sampling_params import RequestOutputKind, SamplingParams
910

1011

@@ -22,7 +23,8 @@ class DetokenizerRequest:
2223
include_stop_str_in_output: bool
2324

2425

25-
class EngineCoreRequest(msgspec.Struct, omit_defaults=True):
26+
@dataclass
27+
class EngineCoreRequest:
2628

2729
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
2830
# but this object is currently not playing well with msgspec
@@ -33,6 +35,9 @@ class EngineCoreRequest(msgspec.Struct, omit_defaults=True):
3335
# always be tokenized?
3436
prompt: Optional[str]
3537
prompt_token_ids: List[int]
38+
mm_data: Optional[MultiModalDataDict]
39+
mm_placeholders: Optional[MultiModalPlaceholderDict]
40+
mm_processor_kwargs: Optional[Dict[str, Any]]
3641
sampling_params: SamplingParams
3742
eos_token_id: Optional[int]
3843
arrival_time: float

vllm/v1/engine/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
EngineCoreRequest, EngineCoreRequestType)
2020
from vllm.v1.executor.gpu_executor import GPUExecutor
2121
from vllm.v1.request import Request, RequestStatus
22+
from vllm.v1.serial_utils import PickleEncoder
2223
from vllm.version import __version__ as VLLM_VERSION
2324

2425
logger = init_logger(__name__)
@@ -315,7 +316,7 @@ def process_input_socket(self, input_path: str):
315316
"""Input socket IO thread."""
316317

317318
# Serialization decoding: pickle for add requests, msgpack for abort requests.
318-
decoder_add_req = msgpack.Decoder(EngineCoreRequest)
319+
decoder_add_req = PickleEncoder()
319320
decoder_abort_req = msgpack.Decoder(list[str])
320321

321322
with self.make_socket(input_path, zmq.constants.PULL) as socket:

vllm/v1/engine/core_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
1212
EngineCoreRequest, EngineCoreRequestType)
1313
from vllm.v1.engine.core import EngineCore, EngineCoreProc
14+
from vllm.v1.serial_utils import PickleEncoder
1415

1516
logger = init_logger(__name__)
1617

@@ -115,7 +116,7 @@ def __init__(
115116
**kwargs,
116117
):
117118
# Serialization setup.
118-
self.encoder = msgspec.msgpack.Encoder()
119+
self.encoder = PickleEncoder()
119120
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
120121

121122
# ZMQ setup.

vllm/v1/engine/processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def process_inputs(
9191
# Make Request for EngineCore.
9292
engine_core_request = EngineCoreRequest(
9393
request_id, processed_inputs.get("prompt"),
94-
processed_inputs.get("prompt_token_ids"), sampling_params,
94+
processed_inputs.get("prompt_token_ids"),
95+
processed_inputs.get("multi_modal_data"),
96+
processed_inputs.get("multi_modal_placeholders"),
97+
processed_inputs.get("mm_processor_kwargs"), sampling_params,
9598
eos_token_id, arrival_time, lora_request)
9699

97100
return detokenizer_request, engine_core_request

vllm/v1/serial_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pickle
2+
3+
4+
class PickleEncoder:
    """Pickle-based serializer exposing an ``encode``/``decode`` pair.

    The two methods are thin wrappers around ``pickle.dumps`` and
    ``pickle.loads``, so one instance can be used on the sending side
    (``encode``) and another on the receiving side (``decode``).

    SECURITY: unpickling can execute arbitrary code. Only ``decode`` bytes
    produced by a trusted peer (e.g. another process of the same
    application); never feed it data from an untrusted source.
    """

    def encode(self, obj) -> bytes:
        """Return ``obj`` serialized to a ``bytes`` object.

        Raises:
            pickle.PicklingError: if ``obj`` (or something it references)
                cannot be pickled.
        """
        # Pin the newest protocol instead of relying on the interpreter's
        # default: it is the most compact/fastest encoding for large binary
        # payloads, and ``pickle.loads`` auto-detects the protocol, so the
        # decoding side needs no matching change.
        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

    def decode(self, data):
        """Return the object reconstructed from pickled ``data``.

        ``data`` may be any bytes-like object accepted by ``pickle.loads``.

        Raises:
            pickle.UnpicklingError: if ``data`` is not a valid pickle stream
                (other exceptions are possible for corrupted input).
        """
        # NOTE(security): pickle.loads runs arbitrary code embedded in the
        # stream. The caller is responsible for making sure ``data`` comes
        # from a trusted sender.
        return pickle.loads(data)

0 commit comments

Comments
 (0)