
Commit 46d9ec2

22quinn authored and epwalsh committed
[Core] Support custom executor qualname (vllm-project#23314)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent e31b3c5 commit 46d9ec2
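
In practice, this commit lets `distributed_executor_backend` name a custom executor either by class object or by its fully qualified import path, which helps when the class cannot be imported at config-construction time. A minimal usage sketch (the qualname `mypkg.executors.MyExecutor` is a hypothetical placeholder for any importable `ExecutorBase` subclass):

# Usage sketch: "mypkg.executors.MyExecutor" is a hypothetical qualname;
# substitute any importable ExecutorBase subclass.
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.llm_engine import LLMEngine

engine_args = EngineArgs(
    model="Qwen/Qwen3-0.6B",
    # New in this commit: a dotted import path string is accepted in
    # addition to "ray"/"mp"/etc. or an ExecutorBase subclass.
    distributed_executor_backend="mypkg.executors.MyExecutor",
)
engine = LLMEngine.from_engine_args(engine_args)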

File tree

6 files changed: +133 −8 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -244,6 +244,7 @@ steps:
 - pytest -v -s v1/core
 - pytest -v -s v1/engine
 - pytest -v -s v1/entrypoints
+- pytest -v -s v1/executor
 - pytest -v -s v1/sample
 - pytest -v -s v1/logits_processors
 - pytest -v -s v1/worker

tests/v1/executor/__init__.py

Whitespace-only changes.

tests/v1/executor/test_executor.py

Lines changed: 116 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from typing import Any, Callable, Optional, Union

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.executor.multiproc_executor import MultiprocExecutor


class Mock:
    ...


class CustomMultiprocExecutor(MultiprocExecutor):

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        # Drop marker to show that this was run
        with open(".marker", "w"):
            ...
        return super().collective_rpc(method, timeout, args, kwargs)


CustomMultiprocExecutorAsync = CustomMultiprocExecutor
MODEL = "Qwen/Qwen3-0.6B"


def test_custom_executor_type_checking():
    with pytest.raises(ValueError):
        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=Mock,
        )
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        engine_args = AsyncEngineArgs(model=MODEL,
                                      gpu_memory_utilization=0.2,
                                      max_model_len=8192,
                                      distributed_executor_backend=Mock)
        AsyncLLM.from_engine_args(engine_args)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutor,
    "tests.v1.executor.test_executor.CustomMultiprocExecutor"
])
def test_custom_executor(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        engine.add_request("0", "foo", sampling_params)
        engine.step()

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutorAsync,
    "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync"
])
def test_custom_executor_async(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        async def t():
            stream = engine.generate(request_id="0",
                                     prompt="foo",
                                     sampling_params=sampling_params)
            async for x in stream:
                ...

        asyncio.run(t())

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)
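
The marker file is the whole verification strategy here: the custom executor drops a `.marker` file inside `collective_rpc`, and each test asserts the file exists after a single engine step, proving the engine routed RPCs through the custom class whether it was passed as a class object or as a qualname string. Locally, these tests should be runnable with the same invocation the pipeline adds, `pytest -v -s v1/executor`, from the `tests` directory.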

vllm/config/parallel.py

Lines changed: 7 additions & 7 deletions
@@ -143,7 +143,8 @@ class ParallelConfig:
     placement_group: Optional[PlacementGroup] = None
     """ray distributed model workers placement group."""

-    distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
+    distributed_executor_backend: Optional[Union[str,
+                                                 DistributedExecutorBackend,
                                                  type[ExecutorBase]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product

@@ -416,23 +417,22 @@ def __post_init__(self) -> None:
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
             isinstance(self.distributed_executor_backend, type)
-            and self.distributed_executor_backend.uses_ray)
+            and getattr(self.distributed_executor_backend, "uses_ray", False))

     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
-        if self.distributed_executor_backend not in (
-                "ray", "mp", "uni",
-                "external_launcher", None) and not (isinstance(
+        if self.distributed_executor_backend is not None and not isinstance(
+                self.distributed_executor_backend, str) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
                         self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
-                " custom ExecutorBase subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher', "
+                " custom ExecutorBase subclass or its import path.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
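
Note the `use_ray` change: a direct attribute access becomes `getattr` with a `False` default, so a custom executor class that never defines `uses_ray` no longer risks an `AttributeError`. A self-contained illustration of the pattern, using hypothetical stand-in classes:

# Hypothetical stand-ins illustrating the getattr pattern from the diff.
class RayBackedExecutor:
    uses_ray = True

class PlainCustomExecutor:
    pass  # defines no uses_ray attribute

for cls in (RayBackedExecutor, PlainCustomExecutor):
    # The old form, cls.uses_ray, would raise AttributeError for
    # PlainCustomExecutor; getattr falls back to False instead.
    print(cls.__name__, getattr(cls, "uses_ray", False))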

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[Union[
-        DistributedExecutorBackend,
+        str, DistributedExecutorBackend,
         Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size

vllm/v1/executor/abstract.py

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,7 @@
     ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
+from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

@@ -50,6 +51,13 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]:
             # TODO: make v1 scheduling deterministic
             # to support external launcher
             executor_class = ExecutorWithExternalLauncher
+        elif isinstance(distributed_executor_backend, str):
+            executor_class = resolve_obj_by_qualname(
+                distributed_executor_backend)
+            if not issubclass(executor_class, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {executor_class}.")
         else:
             raise ValueError("Unknown distributed executor backend: "
                              f"{distributed_executor_backend}")
