Skip to content

Commit

Permalink
[Core] Add MultiprocessingGPUExecutor (vllm-project#4539)
Browse files Browse the repository at this point in the history
Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
  • Loading branch information
njhill and sahilsuneja1 authored May 14, 2024
1 parent de43c0b commit 623495f
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 39 deletions.
12 changes: 8 additions & 4 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ steps:
mirror_hardwares: [amd]
commands:
- pytest -v -s distributed/test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py

- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"
Expand Down
17 changes: 10 additions & 7 deletions tests/distributed/test_basic_distributed_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


Expand All @@ -36,19 +37,21 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
enforce_eager = False
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER":
enforce_eager = True
enforce_eager = backend_by_env_var == "FLASHINFER"

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager)
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

Expand Down
4 changes: 4 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
Expand All @@ -36,6 +37,8 @@ def test_models(
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
Expand All @@ -53,6 +56,7 @@ def test_models(
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
Expand Down
3 changes: 1 addition & 2 deletions tests/lora/test_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size,
worker_use_ray=True)
tensor_parallel_size=tp_size)

expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
Expand Down
38 changes: 29 additions & 9 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,7 @@ class ParallelConfig:
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
worker_use_ray: Deprecated, use distributed_executor_backend instead.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
Expand All @@ -533,37 +531,57 @@ class ParallelConfig:
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
distributed_executor_backend: Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If either
pipeline_parallel_size or tensor_parallel_size is greater than 1,
will default to "ray" if Ray is installed or "mp" otherwise.
"""

def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
worker_use_ray: Optional[bool] = None,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[str] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.distributed_executor_backend = distributed_executor_backend
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
elif self.distributed_executor_backend != "ray":
raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")

if self.distributed_executor_backend is None and self.world_size > 1:
from vllm.executor import ray_utils
ray_found = ray_utils.ray is not None
self.distributed_executor_backend = "ray" if ray_found else "mp"

self._verify_args()

def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
Expand All @@ -575,7 +593,8 @@ def _verify_args(self) -> None:
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
if self.ray_workers_use_nsight and not self.worker_use_ray:
if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"):
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

Expand Down Expand Up @@ -887,7 +906,8 @@ def create_draft_parallel_config(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
worker_use_ray=target_parallel_config.worker_use_ray,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
Expand Down
16 changes: 12 additions & 4 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class EngineArgs:
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
Expand Down Expand Up @@ -221,10 +222,17 @@ def add_cli_args(
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--worker-use-ray',
action='store_true',
help='Deprecated, use --distributed-executor-backend=ray.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
Expand Down
16 changes: 10 additions & 6 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,27 +348,31 @@ def from_engine_args(
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)

if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend.")
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
distributed_executor_backend == "ray",
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
Expand Down
10 changes: 7 additions & 3 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def from_engine_args(
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)

# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
Expand All @@ -282,13 +284,15 @@ def from_engine_args(
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
executor_class = MultiprocessingGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor

Expand Down
Loading

0 comments on commit 623495f

Please sign in to comment.