From 73b246d2941d8147237b3bbcd91648790aa4ee9a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 11 Sep 2024 18:23:36 -0700 Subject: [PATCH] [misc] remove engine_use_ray (#8126) --- tests/async_engine/test_api_server.py | 18 +- tests/async_engine/test_async_llm_engine.py | 14 +- ...i_server_ray.py => test_openapi_server.py} | 7 +- vllm/engine/arg_utils.py | 11 -- vllm/engine/async_llm_engine.py | 163 +++--------------- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/openai/run_batch.py | 1 - vllm/envs.py | 9 - 8 files changed, 32 insertions(+), 197 deletions(-) rename tests/async_engine/{test_openapi_server_ray.py => test_openapi_server.py} (91%) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index a89fa445bf96a..83c71b5cf6eb7 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -1,4 +1,3 @@ -import os import subprocess import sys import time @@ -26,8 +25,7 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int, engine_use_ray: bool, - worker_use_ray: bool): +def api_server(tokenizer_pool_size: int, worker_use_ray: bool): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() commands = [ @@ -37,25 +35,17 @@ def api_server(tokenizer_pool_size: int, engine_use_ray: bool, str(tokenizer_pool_size) ] - # Copy the environment variables and append `VLLM_ALLOW_ENGINE_USE_RAY=1` - # to prevent `--engine-use-ray` raises an exception due to it deprecation - env_vars = os.environ.copy() - env_vars["VLLM_ALLOW_ENGINE_USE_RAY"] = "1" - - if engine_use_ray: - commands.append("--engine-use-ray") if worker_use_ray: commands.append("--worker-use-ray") - uvicorn_process = subprocess.Popen(commands, env=env_vars) + uvicorn_process = subprocess.Popen(commands) yield uvicorn_process.terminate() @pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) @pytest.mark.parametrize("worker_use_ray", [False, True]) -@pytest.mark.parametrize("engine_use_ray", [False, True]) -def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool, - engine_use_ray: bool): +def test_api_server(api_server, tokenizer_pool_size: int, + worker_use_ray: bool): """ Run the API server and test it. 
diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 03494581431d4..3bf11fbcfb3b8 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,5 +1,4 @@ import asyncio -import os from asyncio import CancelledError from dataclasses import dataclass from typing import Optional @@ -72,14 +71,12 @@ def has_unfinished_requests_for_virtual_engine(self, virtual_engine): class MockAsyncLLMEngine(AsyncLLMEngine): - - def _init_engine(self, *args, **kwargs): - return MockEngine() + _engine_class = MockEngine @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) + engine = MockAsyncLLMEngine(worker_use_ray=False) engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -112,16 +109,11 @@ async def test_new_requests_event(): assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 - # Allow deprecated engine_use_ray to not raise exception - os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1" - - engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True) + engine = MockAsyncLLMEngine(worker_use_ray=True) assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None assert engine.get_decoding_config() is not None - os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY") - def start_engine(): wait_for_gpu_memory_to_clear( diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server.py similarity index 91% rename from tests/async_engine/test_openapi_server_ray.py rename to tests/async_engine/test_openapi_server.py index f70118546c7b6..9e5c7c04287eb 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server.py @@ -19,16 +19,11 @@ def server(): "--max-model-len", "2048", "--enforce-eager", - "--engine-use-ray", "--chat-template", str(chatml_jinja_path), ] - # Allow `--engine-use-ray`, otherwise the launch of the server throw - # an error due to try to use a deprecated feature - env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"} - with RemoteOpenAIServer(MODEL_NAME, args, - env_dict=env_dict) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3cd26f6770ed3..6f58c39162087 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1035,7 +1035,6 @@ def create_engine_config(self) -> EngineConfig: @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" - engine_use_ray: bool = False disable_log_requests: bool = False @staticmethod @@ -1043,16 +1042,6 @@ def add_cli_args(parser: FlexibleArgumentParser, async_args_only: bool = False) -> FlexibleArgumentParser: if not async_args_only: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--engine-use-ray', - action='store_true', - help='Use Ray to start the LLM engine in a ' - 'separate process as the server process.' - '(DEPRECATED. This argument is deprecated ' - 'and will be removed in a future update. ' - 'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force ' - 'use it. See ' - 'https://github.com/vllm-project/vllm/issues/7045.' 
- ')') parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6ed1a6bba08ea..362b0f3a44b02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -16,7 +16,7 @@ PromptComponents, SchedulerOutputState) from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt @@ -30,7 +30,6 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import print_warning_once logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -590,9 +589,6 @@ class AsyncLLMEngine: worker_use_ray: Whether to use Ray for model workers. Required for distributed execution. Should be the same as `parallel_config.worker_use_ray`. - engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the - async frontend will be executed in a separate process as the - model workers. log_requests: Whether to log the requests. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. @@ -604,41 +600,23 @@ class AsyncLLMEngine: def __init__(self, worker_use_ray: bool, - engine_use_ray: bool, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs) -> None: self.worker_use_ray = worker_use_ray - self.engine_use_ray = engine_use_ray self.log_requests = log_requests - self.engine = self._init_engine(*args, **kwargs) + self.engine = self._engine_class(*args, **kwargs) # This ensures quick processing of request outputs # so the append to asyncio queues is not delayed, # especially for multi-step. # - # TODO: Currently, disabled for engine_use_ray, ask - # Cody/Will/Woosuk about this case. - self.use_process_request_outputs_callback = not self.engine_use_ray + self.use_process_request_outputs_callback = True if self.use_process_request_outputs_callback: self.engine.process_request_outputs_callback = \ self.process_request_outputs - if self.engine_use_ray: - print_warning_once( - "DEPRECATED. `--engine-use-ray` is deprecated and will " - "be removed in a future update. " - "See https://github.com/vllm-project/vllm/issues/7045.") - - if envs.VLLM_ALLOW_ENGINE_USE_RAY: - print_warning_once( - "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray") - else: - raise ValueError("`--engine-use-ray` is deprecated. " - "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to " - "force use it") - self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded # task as well to prevent it from being garbage @@ -725,16 +703,11 @@ def from_engine_args( # Create the engine configs. engine_config = engine_args.create_engine_config() - if engine_args.engine_use_ray: - from vllm.executor import ray_utils - ray_utils.assert_ray_available() - executor_class = cls._get_executor_cls(engine_config) # Create the async LLM engine. 
engine = cls( executor_class.uses_ray, - engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -777,10 +750,6 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - if self.engine_use_ray: - return await self.engine.get_tokenizer.remote( # type: ignore - lora_request) - return await (self.engine.get_tokenizer_group(). get_lora_tokenizer_async(lora_request)) @@ -814,26 +783,6 @@ def shutdown_background_loop(self) -> None: self._background_loop_unshielded = None self.background_loop = None - def _init_engine(self, *args, - **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: - if not self.engine_use_ray: - engine_class = self._engine_class - elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(self._engine_class).remote - else: - # FIXME(woosuk): This is a bit hacky. Be careful when changing the - # order of the arguments. - cache_config = kwargs["cache_config"] - parallel_config = kwargs["parallel_config"] - if (parallel_config.tensor_parallel_size == 1 - and parallel_config.pipeline_parallel_size == 1): - num_gpus = cache_config.gpu_memory_utilization - else: - num_gpus = 1 - engine_class = ray.remote(num_gpus=num_gpus)( - self._engine_class).remote - return engine_class(*args, **kwargs) - async def engine_step(self, virtual_engine: int) -> bool: """Kick the engine to process the waiting requests. @@ -844,13 +793,8 @@ async def engine_step(self, virtual_engine: int) -> bool: for new_request in new_requests: # Add the request into the vLLM engine's waiting queue. - # TODO: Maybe add add_request_batch to reduce Ray overhead try: - if self.engine_use_ray: - await self.engine.add_request.remote( # type: ignore - **new_request) - else: - await self.engine.add_request_async(**new_request) + await self.engine.add_request_async(**new_request) except ValueError as e: # TODO: use a vLLM specific error for failed validation self._request_tracker.process_exception( @@ -862,10 +806,7 @@ async def engine_step(self, virtual_engine: int) -> bool: if aborted_requests: await self._engine_abort(aborted_requests) - if self.engine_use_ray: - request_outputs = await self.engine.step.remote() # type: ignore - else: - request_outputs = await self.engine.step_async(virtual_engine) + request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. # If used as a callback, then already invoked inside @@ -891,16 +832,10 @@ def process_request_outputs(self, request_outputs) -> bool: return all_finished async def _engine_abort(self, request_ids: Iterable[str]): - if self.engine_use_ray: - await self.engine.abort_request.remote(request_ids) # type: ignore - else: - self.engine.abort_request(request_ids) + self.engine.abort_request(request_ids) async def run_engine_loop(self): - if self.engine_use_ray: - pipeline_parallel_size = 1 # type: ignore - else: - pipeline_parallel_size = \ + pipeline_parallel_size = \ self.engine.parallel_config.pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size while True: @@ -912,12 +847,7 @@ async def run_engine_loop(self): # timeout, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. - if self.engine_use_ray: - await (self.engine.stop_remote_worker_execution_loop. 
- remote() # type: ignore - ) - else: - await self.engine.stop_remote_worker_execution_loop_async() + await self.engine.stop_remote_worker_execution_loop_async() await self._request_tracker.wait_for_new_requests() logger.debug("Got new requests!") requests_in_progress = [ @@ -938,17 +868,9 @@ async def run_engine_loop(self): for task in done: result = task.result() virtual_engine = requests_in_progress.index(task) - if self.engine_use_ray: - has_unfinished_requests = ( - await (self.engine. - has_unfinished_requests_for_virtual_engine. - remote( # type: ignore - virtual_engine))) - else: - has_unfinished_requests = ( - self.engine. - has_unfinished_requests_for_virtual_engine( - virtual_engine)) + has_unfinished_requests = ( + self.engine.has_unfinished_requests_for_virtual_engine( + virtual_engine)) if result or has_unfinished_requests: requests_in_progress[virtual_engine] = ( asyncio.create_task( @@ -1190,52 +1112,29 @@ def _abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" - if self.engine_use_ray: - return await self.engine.get_model_config.remote() # type: ignore - else: - return self.engine.get_model_config() + return self.engine.get_model_config() async def get_parallel_config(self) -> ParallelConfig: """Get the parallel configuration of the vLLM engine.""" - if self.engine_use_ray: - return await self.engine.get_parallel_config.remote( # type: ignore - ) - else: - return self.engine.get_parallel_config() + return self.engine.get_parallel_config() async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" - if self.engine_use_ray: - return await self.engine.get_decoding_config.remote( # type: ignore - ) - else: - return self.engine.get_decoding_config() + return self.engine.get_decoding_config() async def get_scheduler_config(self) -> SchedulerConfig: """Get the scheduling configuration of the vLLM engine.""" - if self.engine_use_ray: - return await self.engine.get_scheduler_config.remote( # type: ignore - ) - else: - return self.engine.get_scheduler_config() + return self.engine.get_scheduler_config() async def get_lora_config(self) -> LoRAConfig: """Get the lora configuration of the vLLM engine.""" - if self.engine_use_ray: - return await self.engine.get_lora_config.remote( # type: ignore - ) - else: - return self.engine.get_lora_config() + return self.engine.get_lora_config() async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None) -> None: - if self.engine_use_ray: - await self.engine.do_log_stats.remote( # type: ignore - scheduler_outputs, model_output) - else: - self.engine.do_log_stats() + self.engine.do_log_stats() async def check_health(self) -> None: """Raises an error if engine is unhealthy.""" @@ -1244,37 +1143,17 @@ async def check_health(self) -> None: if self.is_stopped: raise AsyncEngineDeadError("Background loop is stopped.") - if self.engine_use_ray: - try: - await self.engine.check_health.remote() # type: ignore - except ray.exceptions.RayActorError as e: - raise RuntimeError("Engine is dead.") from e - else: - await self.engine.check_health_async() + await self.engine.check_health_async() logger.debug("Health check took %fs", time.perf_counter() - t) async def is_tracing_enabled(self) -> bool: - if self.engine_use_ray: - return await self.engine.is_tracing_enabled.remote( # type: ignore - ) - else: - return self.engine.is_tracing_enabled() + 
return self.engine.is_tracing_enabled() def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - if self.engine_use_ray: - ray.get( - self.engine.add_logger.remote( # type: ignore - logger_name=logger_name, logger=logger)) - else: - self.engine.add_logger(logger_name=logger_name, logger=logger) + self.engine.add_logger(logger_name=logger_name, logger=logger) def remove_logger(self, logger_name: str) -> None: - if self.engine_use_ray: - ray.get( - self.engine.remove_logger.remote( # type: ignore - logger_name=logger_name)) - else: - self.engine.remove_logger(logger_name=logger_name) + self.engine.remove_logger(logger_name=logger_name) async def start_profile(self) -> None: self.engine.model_executor._run_workers("start_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94271c4a93151..92e46c7af5162 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,8 +3,8 @@ from collections import deque from contextlib import contextmanager from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, - Mapping, NamedTuple, Optional) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -397,7 +397,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Currently used by AsyncLLMEngine to ensure quick append # of request outputs to asyncio queues - self.process_request_outputs_callback = None + self.process_request_outputs_callback: Optional[Callable] = None # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 94f31d272841f..b745410fe6b3b 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -195,7 +195,6 @@ async def main(args): engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) - # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() if args.disable_log_requests: diff --git a/vllm/envs.py b/vllm/envs.py index 9e34b3d08b9e2..b3678399fe207 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -58,7 +58,6 @@ VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 - VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False @@ -391,14 +390,6 @@ def get_default_config_root(): "VLLM_RPC_GET_DATA_TIMEOUT_MS": lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), - # If set, allow running the engine as a separate ray actor, - # which is a deprecated feature soon to be removed. - # See https://github.com/vllm-project/vllm/issues/7045 - "VLLM_ALLOW_ENGINE_USE_RAY": - lambda: - (os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in - ("1", "true")), - # a list of plugin names to load, separated by commas. # if this is not set, it means all plugins will be loaded # if this is set to an empty string, no plugins will be loaded
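
After this patch, the async engine always runs in the same process as the API server: AsyncEngineArgs no longer has an engine_use_ray field, AsyncLLMEngine builds its engine from the _engine_class attribute instead of _init_engine, and the engine.*.remote(...) call paths collapse into direct method calls. Below is a minimal sketch of the post-change construction path, not part of the patch itself; the model name and sampling settings are illustrative only.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # engine_use_ray is gone; only worker_use_ray (forwarded to the executor
    # for distributed model workers) remains on AsyncEngineArgs.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # illustrative model
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # generate() is an async generator that yields incremental RequestOutputs
    # straight from the in-process engine, with no Ray actor round-trip.
    params = SamplingParams(max_tokens=16)
    final = None
    async for output in engine.generate("Hello, my name is", params,
                                        request_id="demo-0"):
        final = output
    if final is not None:
        print(final.outputs[0].text)


asyncio.run(main())

Code that previously stubbed out _init_engine (as the old MockAsyncLLMEngine did) now swaps the engine implementation by overriding the _engine_class class attribute, which is exactly what the updated test above does.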