diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6f06099edd53..b5ea4407ef5b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,7 +46,6 @@ steps: mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - - tests/mq_llm_engine - tests/async_engine - tests/test_inputs.py - tests/test_outputs.py @@ -57,7 +56,6 @@ steps: - tests/transformers_utils commands: - python3 standalone_tests/lazy_imports.py - - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 2bf29ecf087f..e2c83b9c4004 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -10,7 +10,6 @@ import pytest from vllm.config.multimodal import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_models import (BaseModelPath, @@ -18,6 +17,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @@ -82,7 +82,7 @@ def register_mock_resolver(): @pytest.fixture def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 502704c9bbdf..de26fce854f5 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -13,12 +13,12 @@ import pytest_asyncio from vllm.config.multimodal import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM from ...utils import RemoteOpenAIServer @@ -276,7 +276,7 @@ def test_async_serving_chat_init(): @pytest.mark.asyncio async def test_serving_chat_returns_correct_model_name(): - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -312,7 +312,7 @@ async def return_model_name(*args): @pytest.mark.asyncio async def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -355,7 +355,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -410,7 +410,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -467,7 +467,7 @@ async def test_serving_chat_could_load_correct_generation_config(): "repetition_penalty": 1.05 } - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -523,7 +523,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() mock_model_config.hf_config.model_type = model_type - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False diff --git a/tests/mq_llm_engine/__init__.py b/tests/mq_llm_engine/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/mq_llm_engine/conftest.py b/tests/mq_llm_engine/conftest.py deleted file mode 100644 index 375b248ebeda..000000000000 --- a/tests/mq_llm_engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py deleted file mode 100644 index 5ff08cbb3248..000000000000 --- a/tests/mq_llm_engine/test_abort.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that aborting is handled properly.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" -EXPECTED_TOKENS = 250 - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_id_to_be_aborted = "request-aborted" - request_ids_a = [f"request-a-{idx}" for idx in range(10)] - request_ids_b = [f"request-b-{idx}" for idx in range(10)] - - # Requests started before one to be aborted. - tasks = [] - for request_id in request_ids_a: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Aborted. - task_aborted = asyncio.create_task( - generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) - - # Requests started after one to be aborted. - for request_id in request_ids_b: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Actually abort. - await asyncio.sleep(0.5) - await client.abort(request_id_to_be_aborted) - - # Confirm that we got all the EXPECTED tokens from the requests. - for task in tasks: - count, request_id = await task - assert count == EXPECTED_TOKENS, ( - f"{request_id} generated only {count} tokens") - - # Cancel task (this will hang indefinitely if not). - task_aborted.cancel() - - # Shutdown. - client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py deleted file mode 100644 index 77e3732cd06c..000000000000 --- a/tests/mq_llm_engine/test_error_handling.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that various errors are handled properly.""" - -import asyncio -import tempfile -import time -import uuid -from unittest.mock import Mock - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.engine.multiprocessing import MQEngineDeadError -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroupMetadata -from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.execute_model = Mock( - side_effect=RAISED_ERROR(RAISED_VALUE)) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_evil_forward(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_forward) as engine: - - client = await engine.make_client() - - # Server should be healthy after initial probe. - await asyncio.sleep(2.0) - await client.check_health() - - # Throws an error that should get ENGINE_DEAD_ERROR. - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - assert client.errored - - await asyncio.sleep(1.0) - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Shutdown. - client.close() - - -def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, - ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_health_check(tmp_socket): - with RemoteMQLLMEngine( - engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_model_executor_health) as engine: - - client = await engine.make_client() - assert client.is_running - - # Health probe should throw RAISED_ERROR. - await asyncio.sleep(15.) - - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Generate call should throw ENGINE_DEAD_ERROR - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - client.close() - - -def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during abort call. - engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Trigger an abort on the client side. - # This request ID does not exist, and will cause the engine to error - await client.abort(request_id="foo") - - # Future generation requests will now fail - # with reference to the original KeyError("foo") - with pytest.raises(MQEngineDeadError) as execinfo: - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - assert "KeyError" in repr(execinfo.value) - assert client.errored - - # This should raise the original error. - with pytest.raises(RAISED_ERROR): - await client.check_health() - - client.close() - - -@pytest.mark.asyncio -async def test_batch_error(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Batch of requests - async def do_generate(client): - # min_tokens=2048 to keep busy the engine busy - # to get enough time to get process a request - # that will crash the engine - params = SamplingParams(min_tokens=2048, max_tokens=2048) - async for _ in client.generate(prompt="Hello my name is", - sampling_params=params, - request_id=str(uuid.uuid4())): - pass - - tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] - - # This request will force a processing batch to raise - # an exception and next the engine get errored - await client.abort(request_id="foo") - - # The batch of those request failed, then they - # should get the same exception as a MQEngineDeadError. - errors = await asyncio.gather(*tasks, return_exceptions=True) - for e in errors: - assert isinstance(e, MQEngineDeadError) - assert "KeyError" in repr(e) - - client.close() - - -@pytest.mark.asyncio -async def test_bad_request(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - # Invalid request should fail, but not crash the server. - with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-1", - lora_request=LoRARequest( - "invalid-lora", 1, - "invalid-path")): - pass - - # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-2"): - pass - - # Shutdown. - client.close() - - -@pytest.mark.asyncio -async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - # When LLMEngine is loaded, it will crash. - def mock_init(): - raise ValueError - - m.setattr(LLMEngine, "__init__", mock_init) - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 100, ( - "Expected vLLM to gracefully shutdown in <100s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass - - -@pytest.mark.asyncio -async def test_engine_process_death(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - assert client.is_running - - # kill the engine process - engine.proc.kill() - - # Generate call should fail - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - # And the health check should show the engine is dead - with pytest.raises(RuntimeError, match="Engine process .* died"): - await client.check_health() - - client.close() - - -def run_with_evil_input_processing(engine_args: AsyncEngineArgs, - ipc_path: str): - """Simulate an exception while preparing inputs for the model. - In the wild, this could be something like a multimodal input processor - failing on invalid image data.""" - - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - runner = engine.engine.model_executor.driver_worker.worker.model_runner - - # Raise error in the model runner when adding a sequence group. - # See class ModelInputForGPUBuilder - def raiser(_, seq_group_metadata: SequenceGroupMetadata): - if seq_group_metadata.request_id.startswith("evil"): - raise RAISED_ERROR(RAISED_VALUE) - - runner.builder.per_seq_group_compute_fns.append(raiser) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_inputs(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_input_processing) as engine: - - client = await engine.make_client() - assert client.is_running - - # Engine should be healthy - await client.check_health() - - async def run_failing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id="evil" + str(uuid.uuid4())): - pass - - async def run_passing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - - passing_tasks = [ - asyncio.create_task(run_passing_request()) for _ in range(10) - ] - failing_tasks = [ - asyncio.create_task(run_failing_request()) for _ in range(10) - ] - await asyncio.gather(*failing_tasks, return_exceptions=True) - await asyncio.gather(*passing_tasks) - - # All the bad inputs should have raised - for task in failing_tasks: - with pytest.raises(RAISED_ERROR): - task.result() - - # But all good inputs should have still succeeded - for task in passing_tasks: - task.result() - - # And the engine should remain healthy - assert not client.errored - await client.check_health() - - client.close() diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py deleted file mode 100644 index c934706611ae..000000000000 --- a/tests/mq_llm_engine/test_load.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -NUM_EXPECTED_TOKENS = 10 -NUM_REQUESTS = 10000 - -# Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_load(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] - - # Create concurrent requests. - tasks = [] - for request_id in request_ids: - tasks.append( - asyncio.create_task( - generate(client, request_id, NUM_EXPECTED_TOKENS))) - - # Confirm that we got all the EXPECTED tokens from the requests. - failed_request_id = None - tokens = None - for task in tasks: - num_generated_tokens, request_id = await task - if (num_generated_tokens != NUM_EXPECTED_TOKENS - and failed_request_id is None): - failed_request_id = request_id - tokens = num_generated_tokens - - assert failed_request_id is None, ( - f"{failed_request_id} generated {tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") - - # Shutdown. - client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py deleted file mode 100644 index 7976d5031aea..000000000000 --- a/tests/mq_llm_engine/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import multiprocessing -from typing import Callable, Union - -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.outputs import RequestOutput -from vllm.usage.usage_lib import UsageContext - - -async def generate( - client: MQLLMEngineClient, - request_id: str, - num_tokens: int, - return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]: - - final_output = None - count = 0 - async for out in client.generate( - request_id=request_id, - prompt="Hello my name is Robert and", - sampling_params=SamplingParams(max_tokens=num_tokens, - temperature=0)): - - count += 1 - final_output = out - await asyncio.sleep(0.) - - if return_output: - return final_output - - # Confirm we generated all the tokens we expected. - return count, request_id - - -def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Run engine. - engine.start() - - -class RemoteMQLLMEngine: - - def __init__(self, - engine_args: AsyncEngineArgs, - ipc_path: str, - run_fn: Callable = run_normal) -> None: - - self.engine_args = engine_args - self.ipc_path = ipc_path - context = multiprocessing.get_context("spawn") - self.proc = context.Process(target=run_fn, - args=(engine_args, ipc_path)) - self.proc.start() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.proc.kill() - - async def make_client(self) -> MQLLMEngineClient: - engine_config = self.engine_args.create_engine_config() - client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid) - while True: - try: - await client.setup() - break - except TimeoutError: - assert self.proc.is_alive() - return client diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py deleted file mode 100644 index 9f64ee0808df..000000000000 --- a/vllm/engine/multiprocessing/__init__.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import uuid -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Mapping, Optional, Union - -from vllm import PoolingParams -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.utils import Device - -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -IPC_INPUT_EXT = "_input_socket" -IPC_OUTPUT_EXT = "_output_socket" -IPC_HEALTH_EXT = "_health_socket" -IPC_DATA_EXT = "_data_socket" - - -class MQEngineDeadError(RuntimeError): - pass - - -@dataclass -class RPCProcessRequest: - prompt: PromptType - params: Union[SamplingParams, PoolingParams] - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - priority: int = 0 - - def __init__( - self, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> None: - super().__init__() - - self.prompt = prompt - self.params = params - self.request_id = request_id - self.lora_request = lora_request - self.trace_headers = trace_headers - self.priority = priority - - -@dataclass -class RPCError: - request_id: Optional[str] - is_engine_errored: bool - exception: BaseException - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCStartupRequest(Enum): - IS_SERVER_READY = 1 - - -@dataclass -class RPCStartupResponse: - tracing_enabled: bool - - -class RPCUProfileRequest(Enum): - START_PROFILE = 1 - STOP_PROFILE = 2 - - -class RPCResetMultiModalCacheRequest(Enum): - RESET = 1 - - -@dataclass -class RPCResetPrefixCacheRequest: - device: Device - - -class RPCSleepRequest(Enum): - SLEEP_LEVEL_1 = 1 - SLEEP_LEVEL_2 = 2 - - -@dataclass -class RPCWakeUpRequest: - tags: Optional[list[str]] = None - - -@dataclass -class RPCIsSleepingRequest: - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCIsSleepingResponse: - request_id: str - is_sleeping: bool - - -@dataclass -class RPCLoadAdapterRequest: - lora_request: LoRARequest - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCAdapterLoadedResponse: - request_id: str - lora_loaded: bool - - -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, - RPCUProfileRequest, RPCLoadAdapterRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, RPCSleepRequest, - RPCWakeUpRequest, RPCIsSleepingRequest] - -REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, - RPCIsSleepingResponse, RPCError] - - -def ENGINE_DEAD_ERROR( - error: Optional[BaseException] = None) -> MQEngineDeadError: - if error is None: - return MQEngineDeadError( - "Engine loop is not running. Inspect the stacktrace to " - "find the original error") - - return MQEngineDeadError( - "Engine loop is not running. Inspect the stacktrace to " - f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py deleted file mode 100644 index 7d1f29a9824d..000000000000 --- a/vllm/engine/multiprocessing/client.py +++ /dev/null @@ -1,643 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import copy -import pickle -from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, - Mapping, Optional, Union) - -import cloudpickle -import psutil -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import PoolingParams -from vllm.config import DecodingConfig, ModelConfig, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -from vllm.engine.protocol import EngineClient -# yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import Device - -logger = init_logger(__name__) - - -class MQClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class MQLLMEngineClient(EngineClient): - """A client wrapper for MQLLMEngine that conforms to the - EngineClient protocol. - - MQLLMEngine and MQLLMEngineClient are intended to run in separate - processes communicating via zeromq ipc sockets. - - The entrypoint to MQLLMEngineClient is through the generate() - method. On generate() MQLLMEngine does three things: - - Creates an asyncio output queue - - Sends a RPCGenerateRequest to the MQLLMEngine via zmq - - Pulls RequestOutputs from its queue and yields them - - MQLLMEngine runs two background loops: - - output_loop: the output loop pulls List[RequestOutput] - from the MQLLMEngine via zmq (each list is the output - of one engine_step in the LLMEngine). It then parses - the list and pushes individual request_outputs into - the corresponding output_queue such that they can be - consumed by the .generate() method. - - health_loop: the health loop queries the health socket - every N seconds, confirming the engine is healthy - """ - - def __init__(self, ipc_path: str, engine_config: VllmConfig, - engine_pid: int): - self.context = zmq.asyncio.Context() - self._errored_with: Optional[BaseException] = None - - # Get the configs. - self.vllm_config = engine_config - self.model_config = engine_config.model_config - self.decoding_config = engine_config.decoding_config - - if self.vllm_config.model_config.skip_tokenizer_init: - self.tokenizer = None - - else: - # Create the tokenizer group. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=engine_config.scheduler_config, - lora_config=engine_config.lora_config) - - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer) - - # Send RPCGenerateRequest to the MQLLMEngine. - self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") - - # Receive streams of RequestOutput from the MQLLMEngine. - self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # IPC path for acking heartbeats. - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Stream for each individual request. - self.output_queues: Dict[str, asyncio.Queue] = {} - - # Loop to handle output of the LLMEngine periodically. - # Started after the MQLLMEngine is ready so that we can - # build the Client in an executor to enable clean shutdown. - self.output_loop: Optional[asyncio.Task] = None - - # Loop to check health of the LLMEngine periodically. - # Started after the MQLLMEngine is ready. - self.health_loop: Optional[asyncio.Task] = None - self._engine_process = psutil.Process(engine_pid) - - @staticmethod - def is_unsupported_config(vllm_config: VllmConfig): - # Pipeline parallel not yet supported - return vllm_config.parallel_config.pipeline_parallel_size > 1 - - @contextmanager - def get_data_socket(self) -> Iterator[Socket]: - socket = self.context.socket(zmq.constants.DEALER) - try: - socket.connect(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - async def run_heartbeat_loop(self, timeout: int): - """Background loop that continually checks to ensure the engine process - is still alive. - """ - try: - while True: - # Check if the engine process is running: - if not self._engine_process.is_running() or ( - self._engine_process.status() == psutil.STATUS_ZOMBIE): - # NB: is_running() returns True for zombies - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) " - "died.")) - break - - if await self.heartbeat_socket.poll(timeout=timeout): - # Heartbeat received- check the message - await self._check_success( - error_message="Heartbeat failed.", - socket=self.heartbeat_socket) - - logger.debug("Heartbeat successful.") - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient check health loop.") - - except psutil.NoSuchProcess: - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) died.")) - - except Exception as e: - self._set_errored(e) - - async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to Request Queues""" - - try: - while True: - # Poll, checking for ENGINE_DEAD - while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT - ) == 0: - logger.debug("Waiting for output from MQLLMEngine.") - - # If errored, alert all running requests. - if self.errored: - for queue_j in tuple(self.output_queues.values()): - queue_j.put_nowait( - ENGINE_DEAD_ERROR(self._errored_with)) - return - - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - is_error = isinstance(request_outputs, - (BaseException, RPCError)) - if is_error: - if isinstance(request_outputs, RPCError): - rpc_error: RPCError = request_outputs - request_id = rpc_error.request_id - exception = rpc_error.exception - is_engine_errored = rpc_error.is_engine_errored - else: - # MPLLMEngine should always return an RPCError to - # the output_socket when an issue arises. - # If we are here, we are in a bad state and - # should shut down the server. - error: BaseException = request_outputs - logger.error( - "Received Exception %s rather than RPCError from " - "MPLLMEngine. This should never happen.", error) - request_id = None - exception = error - is_engine_errored = True - - # Set to error state only on engine critical error - # (and record only the first one) - if is_engine_errored and not self._errored_with: - self._errored_with = exception - # If engine is errored, no matter the type of exception - # it will no longer be able to receive new requests, - # therefore we have to inform that the current - # processed requests failed as well. Send back a dead - # engine error give this feedback and also give a - # 'hint' to the server to shut down next. - exception = self.dead_error - - if request_id is None: - # If request_id is None, then the engine raised an - # exception for a batch, and we may not know the - # request that caused it, neither if it was actually - # caused by any of them (e.g. CUDA OOM). Therefore we - # broadcast the same exception for all requests. - for queue_i in tuple(self.output_queues.values()): - queue_i.put_nowait(exception) - else: - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(exception) - # Put each output into the appropriate queue. - elif isinstance( - request_outputs, - (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): - self._add_output(request_outputs) - else: - for request_output in request_outputs: - self._add_output(request_output) - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient output handler.") - - def _add_output(self, request_output: Union[RequestOutput, - RPCAdapterLoadedResponse, - RPCIsSleepingResponse]): - queue = self.output_queues.get(request_output.request_id) - if queue is not None: - queue.put_nowait(request_output) - - async def setup(self): - """Set up the client before it starts sending server requests.""" - - # Start output_loop - if self.output_loop is None: - # only generate once to avoid multiple concurrent output_loops - # this will lead to race conditions and wrong orders of tokens - # returned by the engine - # setup will be called multiple times during the startup of - # the engine - self.output_loop = asyncio.create_task( - self.run_output_handler_loop()) - - with self.get_data_socket() as socket: - # Wait until server is ready. - response = await self._wait_for_server_rpc(socket) - - self.tracing_flag = response.tracing_enabled - - # Start health_loop. - if self.health_loop is None: - self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets and terminate the context. - self.context.destroy(linger=0) - - # Cancel background tasks. - if self.health_loop is not None: - self.health_loop.cancel() - if self.output_loop is not None: - self.output_loop.cancel() - - def _set_errored(self, e: BaseException): - logger.exception(repr(e)) - if self._errored_with is None: - self._errored_with = e - - @staticmethod - async def _send_get_data_rpc_request(request: RPCStartupRequest, - expected_type: Any, - error_message: str, - socket: Socket) -> Any: - """Send an RPC request that is expecting data back.""" - - # Ping RPCServer with a request. - await socket.send_multipart((pickle.dumps(request), ), copy=False) - - # Make sure the server responds in time. - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("RPCServer didn't reply within " - f"{VLLM_RPC_TIMEOUT} ms") - - # Await the data from the Server. - frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, BaseException): - raise data - elif not isinstance(data, expected_type): - raise ValueError(error_message) - - return data - - @staticmethod - async def _send_one_way_rpc_request(request: RPC_REQUEST_T, - socket: Socket): - """Send one-way RPC request to trigger an action.""" - - if socket.closed: - raise MQClientClosedError() - - await socket.send_multipart((pickle.dumps(request), )) - - async def _await_ack(self, error_message: str, socket: Socket): - """Await acknowledgement that a request succeeded.""" - - if socket.closed: - raise MQClientClosedError() - - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("MQLLMEngine didn't reply within " - f"{VLLM_RPC_TIMEOUT}ms") - - await self._check_success(error_message, socket) - - @staticmethod - async def _check_success(error_message: str, socket: Socket): - """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" - - if socket.closed: - raise MQClientClosedError() - - frame = await socket.recv(copy=False) - response = pickle.loads(frame.buffer) - - # Raise error if unsuccessful - if isinstance(response, BaseException): - raise response - elif (not isinstance(response, str) - or response != VLLM_RPC_SUCCESS_STR): - raise ValueError(error_message) - - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.input_preprocessor - - async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): - if self.tokenizer is None: - return None - else: - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_vllm_config(self) -> VllmConfig: - return self.vllm_config - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: - """Wait for the RPCServer to start up.""" - - return await self._send_get_data_rpc_request( - request=RPCStartupRequest.IS_SERVER_READY, - expected_type=RPCStartupResponse, - error_message="Unable to start RPC Server", - socket=socket) - - async def abort(self, request_id: Union[str, Iterable[str]]): - """Send an ABORT_REQUEST signal to the RPC Server""" - - if not isinstance(request_id, str): - raise RuntimeError("Only single-request abort supported in" - " deprecated V0") - - with suppress(MQClientClosedError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), socket=self.input_socket) - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - ) -> None: - """ - Ignore do_log_stats (handled on MQLLMEngine polling) - """ - pass - - async def check_health(self): - """ - The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with - if the engine is unhealthy. - """ - if self._errored_with is not None: - raise self._errored_with - - @property - def is_running(self) -> bool: - return not self.errored - - @property - def is_stopped(self) -> bool: - return self.errored - - @property - def errored(self) -> bool: - return self._errored_with is not None - - @property - def dead_error(self) -> BaseException: - return ENGINE_DEAD_ERROR(self._errored_with) - - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - sampling_params: The sampling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the - scheduling policy is not "priority". - """ - return self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, priority) - - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - raise NotImplementedError( - "Pooling models are not supported in vLLM V0") - - async def _process_request( - self, - prompt: PromptType, - params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - # If already dead, error out. - if self._errored_with is not None: - raise ENGINE_DEAD_ERROR(self._errored_with) - - # Ensure the request id is unique among running requests - if request_id in self.output_queues: - raise ValueError(f"Request {request_id} already exists") - - # 1) Create output queue for this request. - queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() - self.output_queues[request_id] = queue - - try: - # 2) Detach logits processors so that they can be pickled - # separately (may require cloudpickle which is slower) - if params.logits_processors: - # Defensive shallow copy - params = copy.copy(params) - logits_processors = params.logits_processors - params.logits_processors = None - lp_bytes = cloudpickle.dumps(logits_processors) - else: - lp_bytes = None - - request_bytes = pickle.dumps( - RPCProcessRequest( - prompt=prompt, - params=params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - )) - - # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) - await self.input_socket.send_multipart(parts, copy=False) - - # 4) Stream the RequestOutputs from the output queue. Note - # that the output_loop pushes RequestOutput objects to this - # queue after pulling them from the zmq socket. - finished = False - try: - while not finished: - request_output = await queue.get() - - if isinstance(request_output, BaseException): - raise request_output - - finished = request_output.finished - yield request_output - finally: - # Request was canceled by the client. - if not finished and not self.errored: - await self.abort(request_id) - finally: - self.output_queues.pop(request_id) - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) - - async def reset_mm_cache(self) -> None: - """Reset the multi-modal cache""" - - await self._send_one_way_rpc_request( - request=RPCResetMultiModalCacheRequest.RESET, - socket=self.input_socket) - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - """Reset the prefix cache""" - - await self._send_one_way_rpc_request( - request=RPCResetPrefixCacheRequest(device), - socket=self.input_socket) - - async def sleep(self, level: int = 1) -> None: - """Sleep the engine for a given level""" - return await self._send_one_way_rpc_request( - request=RPCSleepRequest(level), socket=self.input_socket) - - async def wake_up(self, tags: Optional[list[str]] = None) -> None: - """Wake up the engine""" - return await self._send_one_way_rpc_request( - request=RPCWakeUpRequest(tags), socket=self.input_socket) - - async def is_sleeping(self) -> bool: - """Check whether the engine is sleeping""" - request = RPCIsSleepingRequest() - - queue: asyncio.Queue[Union[BaseException, - RPCIsSleepingResponse]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - if isinstance(request_output, BaseException): - raise request_output - return request_output.is_sleeping - - async def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - # Uses the same I/O as generate requests - request = RPCLoadAdapterRequest(lora_request) - - # Create output queue for this request. - queue: asyncio.Queue[Union[ - BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - # Send the request - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - # Wait for the response - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - # Raise on error, otherwise happily return None - if isinstance(request_output, BaseException): - raise request_output - return request_output.lora_loaded diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py deleted file mode 100644 index 138283d4c8a7..000000000000 --- a/vllm/engine/multiprocessing/engine.py +++ /dev/null @@ -1,470 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pickle -import signal -from contextlib import contextmanager -from typing import Iterator, List, Optional, Union - -import cloudpickle -import zmq - -from vllm import AsyncEngineArgs, SamplingParams -from vllm.config import VllmConfig -from vllm.engine.llm_engine import LLMEngine -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -# yapf: enable -from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs -from vllm.worker.model_runner_base import InputProcessingError - -logger = init_logger(__name__) - -POLLING_TIMEOUT_MS = 10000 -HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) - - -class MQLLMEngine: - """A multiprocessing wrapper for - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - - This class is used to wrap the - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use - in concurrent manner. It runs a background loop and uses zeromq to - receive new requests and stream outputs incrementally via ipc. - - The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode - process is kicked off when a new RPCProcessRequest is received by the - input_socket. - - The self.engine_loop checks the input_socket for new requests, - adds them to the LLMEngine if there are any, calls the internal - [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends - the RequestOutputs back over the output_socket. - - If use_async_sockets is set, the logic associated with reading new - requests from the socket and sending data to the socket is passed - as a callback to the llm_engine, which calls the logic asynchronously - such that the IPC can be overlapped with the GPU. - - Args: - ipc_path: Base path for zeromq interprocess messaging - use_async_sockets: Whether to make send/recv async with GPU - log_requests: Whether to log the requests. - *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - """ - - def __init__(self, - ipc_path: str, - use_async_sockets: bool, - *args, - log_requests: bool = True, - **kwargs) -> None: - # For MQLLMEngine, we can use cached outputs, since each new request - # output is immediately pickled and send over the socket, which frees - # the python object to be reused again. - kwargs['use_cached_outputs'] = True - - self.engine = LLMEngine(*args, **kwargs) - self.log_requests = log_requests - - self.use_async_sockets = use_async_sockets - if self.use_async_sockets: - self.engine.process_request_outputs_callback = \ - self._async_socket_engine_callback - - self.ctx = zmq.Context() # type: ignore[attr-defined] - - # Receive input from the client. - self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") - - # Send output stream back to client. - self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # Send heartbeats back to client. - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Error state. - self._errored_with: Optional[BaseException] = None - - @property - def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return ENGINE_DEAD_ERROR() - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - usage_context: UsageContext, - enable_log_requests: bool, - disable_log_stats: bool, - ipc_path: str, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "MQLLMEngine": - # Setup plugins for each process - from vllm.plugins import load_general_plugins - load_general_plugins() - - use_async_sockets = vllm_config.model_config.use_async_output_proc - - return cls( - vllm_config=vllm_config, - executor_class=LLMEngine._get_executor_cls(vllm_config), - ipc_path=ipc_path, - usage_context=usage_context, - use_async_sockets=use_async_sockets, - log_requests=enable_log_requests, - log_stats=(not disable_log_stats), - ) - - @staticmethod - def from_engine_args(engine_args: AsyncEngineArgs, - usage_context: UsageContext, ipc_path: str): - """Creates an MQLLMEngine from the engine arguments.""" - - vllm_config = engine_args.create_engine_config(usage_context) - return MQLLMEngine.from_vllm_config( - ipc_path=ipc_path, - vllm_config=vllm_config, - usage_context=usage_context, - enable_log_requests=engine_args.enable_log_requests, - disable_log_stats=engine_args.disable_log_stats, - ) - - def start(self): - try: - try: - logger.debug("Starting Startup Loop.") - self.run_startup_loop() - logger.debug("Starting Engine Loop.") - self.run_engine_loop() - except Exception as e: - logger.exception(repr(e)) - except KeyboardInterrupt: - logger.debug("Shutting down MQLLMEngine.") - finally: - logger.debug("MQLLMEngine is shut down.") - self.cleanup() - - def cleanup(self): - """Cleanup zeromq state on shutdown.""" - # Closes all sockets and destroys context. - self.ctx.destroy(linger=0) - del self.engine - - @contextmanager - def make_data_socket( - self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] - socket = self.ctx.socket(zmq.constants.ROUTER) - try: - socket.bind(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - def run_startup_loop(self) -> None: - """Startup loop for sending data from Engine -> Client.""" - - with self.make_data_socket() as socket: - response: Union[RPCStartupResponse, BaseException] - try: - identity, message = socket.recv_multipart(copy=False) - request: RPCStartupRequest = pickle.loads(message.buffer) - - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) - - except Exception as e: - response = e - - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) - - def run_engine_loop(self): - """Core busy loop of the LLMEngine.""" - - while True: - if not self.engine.has_unfinished_requests(): - # Poll until there is work to do. - while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - # When there's no work, check on engine health and send - # health status back to client - self._health_check() - self.engine.do_log_stats() - logger.debug("Waiting for new requests in engine loop.") - - # Handle any input from the client. - self.handle_new_input() - - # Engine step. - request_outputs = self.engine_step() - - # Send request outputs (if async, done in engine_step callback). - if not self.use_async_sockets: - self._send_outputs(request_outputs) - - def engine_step(self) -> List[RequestOutput]: - """Engine step wrapper with error handling.""" - try: - return self.engine.step() - except SystemExit: - raise - except InputProcessingError as e: - # Special case where we handle an error preparing the inputs for - # a single request in the batch - rpc_err = RPCError(request_id=e.request_id, - is_engine_errored=False, - exception=e.__cause__) - self._send_outputs(rpc_err) - return [] - except BaseException as e: - self._set_errored(e) - rpc_err = RPCError(request_id=None, - is_engine_errored=True, - exception=e) - self._send_outputs(rpc_err) - raise e - - def handle_new_input(self): - """Handle new input from the socket""" - try: - while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) - - if isinstance(request, RPCProcessRequest): - if len(frames) > 1: - # Use cloudpickle for logits processors - assert isinstance(request.params, SamplingParams) - lprocs = cloudpickle.loads(frames[1].buffer) - request.params.logits_processors = lprocs - self._handle_process_request(request) - elif isinstance(request, RPCAbortRequest): - self._handle_abort_request(request) - elif isinstance(request, RPCUProfileRequest): - if request == RPCUProfileRequest.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(request, RPCLoadAdapterRequest): - self._handle_load_adapter_request(request) - elif isinstance(request, RPCResetMultiModalCacheRequest): - self.reset_mm_cache() - elif isinstance(request, RPCResetPrefixCacheRequest): - self.reset_prefix_cache() - elif isinstance(request, RPCSleepRequest): - self.sleep(request.value) - elif isinstance(request, RPCWakeUpRequest): - self.wake_up(request.tags) - elif isinstance(request, RPCIsSleepingRequest): - self._handle_is_sleeping_request(request) - else: - raise ValueError("Unknown RPCRequest Type: " - f"{type(request)}") - - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - raise e from None - - def _handle_process_request(self, request: RPCProcessRequest): - """Handle RPCProcessRequest by adding it to the LLMEngine.""" - request_id = request.request_id - - if self._errored_with is not None: - rpc_err = RPCError(request_id=request_id, - is_engine_errored=True, - exception=ENGINE_DEAD_ERROR(self._errored_with)) - self._send_outputs(rpc_err) - - try: - self.engine.add_request(request_id=request_id, - prompt=request.prompt, - params=request.params, - lora_request=request.lora_request, - trace_headers=request.trace_headers, - priority=request.priority) - - if self.log_requests: - logger.info("Added request %s.", request.request_id) - - except Exception as e: - # We do not set self._errored = True here, since the error - # is due to an issue adding this request to the engine, - # rather than an issue with the engine itself. - logger.debug("Failed to add request %s to engine. %s", - request.request_id, e) - is_errored = self._errored_with is not None - rpc_err = RPCError(request_id=request_id, - is_engine_errored=is_errored, - exception=e) - self._send_outputs(rpc_err) - - # Remove request from the engine. - self.engine.abort_request(request_id) - - def _handle_abort_request(self, request: RPCAbortRequest): - self.engine.abort_request(request.request_id) - if self.log_requests: - logger.info("Aborted request %s.", request.request_id) - - def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): - try: - lora_loaded = self.engine.add_lora(request.lora_request) - except BaseException as e: - # Send back an error if the adater fails to load - rpc_err = RPCError(request_id=request.request_id, - is_engine_errored=False, - exception=e) - self._send_outputs(rpc_err) - return - # Otherwise, send back the successful load message - self._send_outputs( - RPCAdapterLoadedResponse(request_id=request.request_id, - lora_loaded=lora_loaded)) - - def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): - is_sleeping = self.is_sleeping() - self._send_outputs( - RPCIsSleepingResponse(request_id=request.request_id, - is_sleeping=is_sleeping)) - - def _health_check(self): - # Send unhealthy if engine has already errored - if self._errored_with is not None: - self._send_unhealthy(self._errored_with) - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - - def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): - """Send outputs back to the engine client. These can be: - - Exceptions - - A list of generation outputs - - A response from loading a lora adapter - """ - if outputs: - try: - from ray.exceptions import RayTaskError - - # RayTaskError might not pickelable here. We need to unpack the - # underlying exception as the real exception in the output. - if (isinstance(outputs, RPCError) - and isinstance(outputs.exception, RayTaskError)): - outputs.exception = outputs.exception.cause - except ImportError: - pass - - output_bytes = pickle.dumps(outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) - - def _send_healthy(self): - """Send HEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - - def _send_unhealthy(self, error: BaseException): - """Send UNHEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - error_bytes = pickle.dumps(error) - self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) - - def _async_socket_engine_callback(self, - request_outputs: REQUEST_OUTPUTS_T): - """Callback used by engine to make socket handling async with GPU.""" - self._send_outputs(request_outputs) - self.handle_new_input() - - def _set_errored(self, e: BaseException): - """Log and set errored status if this is the first issue.""" - if self._errored_with is None: - self._errored_with = e - - def start_profile(self) -> None: - self.engine.start_profile() - - def stop_profile(self) -> None: - self.engine.stop_profile() - - def reset_mm_cache(self) -> bool: - return self.engine.reset_mm_cache() - - def reset_prefix_cache(self) -> bool: - return self.engine.reset_prefix_cache() - - def sleep(self, level: int = 1) -> None: - self.engine.sleep(level) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - self.engine.wake_up(tags) - - def is_sleeping(self) -> bool: - return self.engine.is_sleeping() - - -def signal_handler(*_) -> None: - raise KeyboardInterrupt("MQLLMEngine terminated") - - -def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, - ipc_path: str, disable_log_stats: bool, - enable_log_requests: bool, engine_alive): - try: - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - engine = MQLLMEngine.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - disable_log_stats=disable_log_stats, - enable_log_requests=enable_log_requests, - ipc_path=ipc_path) - - signal.signal(signal.SIGTERM, signal_handler) - - engine.start() - - except BaseException as e: - logger.exception(e) - engine_alive.value = False - raise e from None diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 887e27710924..c3195dbc4697 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,7 +12,6 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) @@ -156,7 +155,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: @app.exception_handler(RuntimeError) @app.exception_handler(AsyncEngineDeadError) - @app.exception_handler(MQEngineDeadError) @app.exception_handler(EngineDeadError) @app.exception_handler(EngineGenerateError) async def runtime_exception_handler(request: Request, __): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e4aa7f3d5a6..527193c91339 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import atexit import gc import importlib import inspect @@ -17,7 +16,6 @@ from argparse import Namespace from collections.abc import AsyncGenerator, AsyncIterator, Awaitable from contextlib import asynccontextmanager -from functools import partial from http import HTTPStatus from typing import Annotated, Any, Callable, Optional @@ -42,8 +40,6 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (load_chat_template, resolve_hf_chat_template, @@ -102,13 +98,10 @@ log_non_default_args, with_cancellation) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs, - get_open_zmq_ipc_path, is_valid_ipv6_address, - set_ulimit) + is_valid_ipv6_address, set_ulimit) from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION @@ -237,8 +230,7 @@ async def build_async_engine_client_from_engine_args( async_llm.shutdown() # V0 AsyncLLM. - elif (MQLLMEngineClient.is_unsupported_config(vllm_config) - or disable_frontend_multiprocessing): + else: engine_client: Optional[EngineClient] = None try: @@ -252,96 +244,6 @@ async def build_async_engine_client_from_engine_args( if engine_client and hasattr(engine_client, "shutdown"): engine_client.shutdown() - # V0MQLLMEngine. - else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - # Make TemporaryDirectory for prometheus multiprocessing - # Note: global TemporaryDirectory will be automatically - # cleaned up upon exit. - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - else: - logger.warning( - "Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") - - # Select random path for IPC. - ipc_path = get_open_zmq_ipc_path() - logger.debug("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) - - # Start RPCServer in separate process (holds the LLMEngine). - # the current process might have CUDA context, - # so we need to spawn a new process - context = multiprocessing.get_context("spawn") - - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - # The Process can raise an exception during startup, which may - # not actually result in an exitcode being reported. As a result - # we use a shared variable to communicate the information. - engine_alive = multiprocessing.Value('b', True, lock=False) - engine_process = context.Process( - target=run_mp_engine, - args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, - engine_args.disable_log_stats, - engine_args.enable_log_requests, engine_alive)) - engine_process.start() - engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start." - logger.info("Started engine process with PID %d", engine_pid) - - def _cleanup_ipc_path(): - socket_path = ipc_path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - # Ensure we clean up the local IPC socket file on exit. - atexit.register(_cleanup_ipc_path) - - # Build RPCClient, which conforms to EngineClient Protocol. - build_client = partial(MQLLMEngineClient, ipc_path, vllm_config, - engine_pid) - mq_engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_client) - try: - while True: - try: - await mq_engine_client.setup() - break - except TimeoutError: - if (not engine_process.is_alive() - or not engine_alive.value): - raise RuntimeError( - "Engine process failed to start. See stack " - "trace for the root cause.") from None - - yield mq_engine_client # type: ignore[misc] - finally: - # Ensure rpc server process was terminated - engine_process.terminate() - - # Close all open connections to the backend - mq_engine_client.close() - - # Wait for engine process to join - engine_process.join(4) - if engine_process.exitcode is None: - # Kill if taking longer than 5 seconds to stop - engine_process.kill() - - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import multiprocess - multiprocess.mark_process_dead(engine_process.pid) - async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index bb8bff48c7b9..4f540fe965e2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -191,7 +191,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: - from vllm.attention.backends.rocm_aiter_mla import ( + from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) if selected_backend is None: