Skip to content

Commit 60371ed

Browse files
njhillxuebwang-amd
authored andcommitted
[Misc] Support more collective_rpc return types (vllm-project#25294)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 4637e4b commit 60371ed

File tree

2 files changed

+246
-17
lines changed

2 files changed

+246
-17
lines changed

tests/v1/engine/test_engine_core_client.py

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
from dataclasses import dataclass
1010
from threading import Thread
11-
from typing import Optional, Union
11+
from typing import Any, Optional, Union
1212
from unittest.mock import MagicMock
1313

1414
import pytest
@@ -331,6 +331,46 @@ def echo_dc(
331331
return [val for _ in range(3)] if return_list else val
332332

333333

334+
# Dummy utility function to test dict serialization with custom types.
335+
def echo_dc_dict(
336+
self,
337+
msg: str,
338+
return_dict: bool = False,
339+
) -> Union[MyDataclass, dict[str, MyDataclass]]:
340+
print(f"echo dc dict util function called: {msg}")
341+
val = None if msg is None else MyDataclass(msg)
342+
# Return dict of dataclasses to verify support for returning dicts
343+
# with custom value types.
344+
if return_dict:
345+
return {"key1": val, "key2": val, "key3": val}
346+
else:
347+
return val
348+
349+
350+
# Dummy utility function to test nested structures with custom types.
351+
def echo_dc_nested(
352+
self,
353+
msg: str,
354+
structure_type: str = "list_of_dicts",
355+
) -> Any:
356+
print(f"echo dc nested util function called: {msg}, "
357+
f"structure: {structure_type}")
358+
val = None if msg is None else MyDataclass(msg)
359+
360+
if structure_type == "list_of_dicts": # noqa
361+
# Return list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
362+
return [{"a": val, "b": val}, {"c": val, "d": val}]
363+
elif structure_type == "dict_of_lists":
364+
# Return dict of lists: {"list1": [val, val], "list2": [val, val]}
365+
return {"list1": [val, val], "list2": [val, val]}
366+
elif structure_type == "deep_nested":
367+
# Return deeply nested: {"outer": [{"inner": [val, val]},
368+
# {"inner": [val]}]}
369+
return {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
370+
else:
371+
return val
372+
373+
334374
@pytest.mark.asyncio(loop_scope="function")
335375
async def test_engine_core_client_util_method_custom_return(
336376
monkeypatch: pytest.MonkeyPatch):
@@ -384,6 +424,167 @@ async def test_engine_core_client_util_method_custom_return(
384424
client.shutdown()
385425

386426

427+
@pytest.mark.asyncio(loop_scope="function")
428+
async def test_engine_core_client_util_method_custom_dict_return(
429+
monkeypatch: pytest.MonkeyPatch):
430+
431+
with monkeypatch.context() as m:
432+
m.setenv("VLLM_USE_V1", "1")
433+
434+
# Must set insecure serialization to allow returning custom types.
435+
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
436+
437+
# Monkey-patch core engine utility function to test.
438+
m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)
439+
440+
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
441+
vllm_config = engine_args.create_engine_config(
442+
usage_context=UsageContext.UNKNOWN_CONTEXT)
443+
executor_class = Executor.get_class(vllm_config)
444+
445+
with set_default_torch_num_threads(1):
446+
client = EngineCoreClient.make_client(
447+
multiprocess_mode=True,
448+
asyncio_mode=True,
449+
vllm_config=vllm_config,
450+
executor_class=executor_class,
451+
log_stats=True,
452+
)
453+
454+
try:
455+
# Test utility method returning custom / non-native data type.
456+
core_client: AsyncMPClient = client
457+
458+
# Test single object return
459+
result = await core_client.call_utility_async(
460+
"echo_dc_dict", "testarg3", False)
461+
assert isinstance(result,
462+
MyDataclass) and result.message == "testarg3"
463+
464+
# Test dict return with custom value types
465+
result = await core_client.call_utility_async(
466+
"echo_dc_dict", "testarg3", True)
467+
assert isinstance(result, dict) and len(result) == 3
468+
for key, val in result.items():
469+
assert key in ["key1", "key2", "key3"]
470+
assert isinstance(val,
471+
MyDataclass) and val.message == "testarg3"
472+
473+
# Test returning dict with None values
474+
result = await core_client.call_utility_async(
475+
"echo_dc_dict", None, True)
476+
assert isinstance(result, dict) and len(result) == 3
477+
for key, val in result.items():
478+
assert key in ["key1", "key2", "key3"]
479+
assert val is None
480+
481+
finally:
482+
client.shutdown()
483+
484+
485+
@pytest.mark.asyncio(loop_scope="function")
486+
async def test_engine_core_client_util_method_nested_structures(
487+
monkeypatch: pytest.MonkeyPatch):
488+
489+
with monkeypatch.context() as m:
490+
m.setenv("VLLM_USE_V1", "1")
491+
492+
# Must set insecure serialization to allow returning custom types.
493+
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
494+
495+
# Monkey-patch core engine utility function to test.
496+
m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)
497+
498+
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
499+
vllm_config = engine_args.create_engine_config(
500+
usage_context=UsageContext.UNKNOWN_CONTEXT)
501+
executor_class = Executor.get_class(vllm_config)
502+
503+
with set_default_torch_num_threads(1):
504+
client = EngineCoreClient.make_client(
505+
multiprocess_mode=True,
506+
asyncio_mode=True,
507+
vllm_config=vllm_config,
508+
executor_class=executor_class,
509+
log_stats=True,
510+
)
511+
512+
try:
513+
core_client: AsyncMPClient = client
514+
515+
# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
516+
result = await core_client.call_utility_async(
517+
"echo_dc_nested", "nested1", "list_of_dicts")
518+
assert isinstance(result, list) and len(result) == 2
519+
for i, item in enumerate(result):
520+
assert isinstance(item, dict)
521+
if i == 0:
522+
assert "a" in item and "b" in item
523+
assert isinstance(
524+
item["a"],
525+
MyDataclass) and item["a"].message == "nested1"
526+
assert isinstance(
527+
item["b"],
528+
MyDataclass) and item["b"].message == "nested1"
529+
else:
530+
assert "c" in item and "d" in item
531+
assert isinstance(
532+
item["c"],
533+
MyDataclass) and item["c"].message == "nested1"
534+
assert isinstance(
535+
item["d"],
536+
MyDataclass) and item["d"].message == "nested1"
537+
538+
# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
539+
result = await core_client.call_utility_async(
540+
"echo_dc_nested", "nested2", "dict_of_lists")
541+
assert isinstance(result, dict) and len(result) == 2
542+
assert "list1" in result and "list2" in result
543+
for key, lst in result.items():
544+
assert isinstance(lst, list) and len(lst) == 2
545+
for item in lst:
546+
assert isinstance(
547+
item, MyDataclass) and item.message == "nested2"
548+
549+
# Test deeply nested: {"outer": [{"inner": [val, val]},
550+
# {"inner": [val]}]}
551+
result = await core_client.call_utility_async(
552+
"echo_dc_nested", "nested3", "deep_nested")
553+
assert isinstance(result, dict) and "outer" in result
554+
outer_list = result["outer"]
555+
assert isinstance(outer_list, list) and len(outer_list) == 2
556+
557+
# First dict in outer list should have "inner" with 2 items
558+
inner_dict1 = outer_list[0]
559+
assert isinstance(inner_dict1, dict) and "inner" in inner_dict1
560+
inner_list1 = inner_dict1["inner"]
561+
assert isinstance(inner_list1, list) and len(inner_list1) == 2
562+
for item in inner_list1:
563+
assert isinstance(item,
564+
MyDataclass) and item.message == "nested3"
565+
566+
# Second dict in outer list should have "inner" with 1 item
567+
inner_dict2 = outer_list[1]
568+
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
569+
inner_list2 = inner_dict2["inner"]
570+
assert isinstance(inner_list2, list) and len(inner_list2) == 1
571+
assert isinstance(
572+
inner_list2[0],
573+
MyDataclass) and inner_list2[0].message == "nested3"
574+
575+
# Test with None values in nested structures
576+
result = await core_client.call_utility_async(
577+
"echo_dc_nested", None, "list_of_dicts")
578+
assert isinstance(result, list) and len(result) == 2
579+
for item in result:
580+
assert isinstance(item, dict)
581+
for val in item.values():
582+
assert val is None
583+
584+
finally:
585+
client.shutdown()
586+
587+
387588
@pytest.mark.parametrize(
388589
"multiprocessing_mode,publisher_config",
389590
[(True, "tcp"), (False, "inproc")],

vllm/v1/serial_utils.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Sequence
88
from inspect import isclass
99
from types import FunctionType
10-
from typing import Any, Optional, Union
10+
from typing import Any, Callable, Optional, Union
1111

1212
import cloudpickle
1313
import msgspec
@@ -59,6 +59,42 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]:
5959
return t.__module__, t.__qualname__
6060

6161

62+
def _encode_type_info_recursive(obj: Any) -> Any:
63+
"""Recursively encode type information for nested structures of
64+
lists/dicts."""
65+
if obj is None:
66+
return None
67+
if type(obj) is list:
68+
return [_encode_type_info_recursive(item) for item in obj]
69+
if type(obj) is dict:
70+
return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
71+
return _typestr(obj)
72+
73+
74+
def _decode_type_info_recursive(
75+
type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
76+
Any]) -> Any:
77+
"""Recursively decode type information for nested structures of
78+
lists/dicts."""
79+
if type_info is None:
80+
return data
81+
if isinstance(type_info, dict):
82+
assert isinstance(data, dict)
83+
return {
84+
k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
85+
for k in type_info
86+
}
87+
if isinstance(type_info, list) and (
88+
# Exclude serialized tensors/numpy arrays.
89+
len(type_info) != 2 or not isinstance(type_info[0], str)):
90+
assert isinstance(data, list)
91+
return [
92+
_decode_type_info_recursive(ti, d, convert_fn)
93+
for ti, d in zip(type_info, data)
94+
]
95+
return convert_fn(type_info, data)
96+
97+
6298
class MsgpackEncoder:
6399
"""Encoder with custom torch tensor and numpy array serialization.
64100
@@ -129,12 +165,10 @@ def enc_hook(self, obj: Any) -> Any:
129165
result = obj.result
130166
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
131167
return None, result
132-
# Since utility results are not strongly typed, we also encode
133-
# the type (or a list of types in the case it's a list) to
134-
# help with correct msgspec deserialization.
135-
return _typestr(result) if type(result) is not list else [
136-
_typestr(v) for v in result
137-
], result
168+
# Since utility results are not strongly typed, we recursively
169+
# encode type information for nested structures of lists/dicts
170+
# to help with correct msgspec deserialization.
171+
return _encode_type_info_recursive(result), result
138172

139173
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
140174
raise TypeError(f"Object of type {type(obj)} is not serializable"
@@ -288,15 +322,9 @@ def _decode_utility_result(self, obj: Any) -> UtilityResult:
288322
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
289323
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
290324
"be set to use custom utility result types")
291-
assert isinstance(result_type, list)
292-
if len(result_type) == 2 and isinstance(result_type[0], str):
293-
result = self._convert_result(result_type, result)
294-
else:
295-
assert isinstance(result, list)
296-
result = [
297-
self._convert_result(rt, r)
298-
for rt, r in zip(result_type, result)
299-
]
325+
# Use recursive decoding to handle nested structures
326+
result = _decode_type_info_recursive(result_type, result,
327+
self._convert_result)
300328
return UtilityResult(result)
301329

302330
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:

0 commit comments

Comments
 (0)