diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 24ed5d392839..7ebccdb5caed 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -118,14 +118,16 @@ def model(x):
 
 @fork_new_process_for_each_test
 @pytest.mark.parametrize(
-    "model",
+    "model, use_v1",
     [
         # sleep mode with safetensors
-        f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B",
+        (f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", True),
         # sleep mode with pytorch checkpoint
-        "facebook/opt-125m"
+        ("facebook/opt-125m", False),
     ])
-def test_end_to_end(model):
+def test_end_to_end(model: str, use_v1: bool):
+    import os
+    os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
     free, total = torch.cuda.mem_get_info()
     used_bytes_baseline = total - free  # in case other process is running
     load_format = LoadFormat.AUTO
@@ -152,3 +154,5 @@ def test_end_to_end(model):
 
     # cmp output
     assert output[0].outputs[0].text == output2[0].outputs[0].text
+
+    del os.environ["VLLM_USE_V1"]
diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py
new file mode 100644
index 000000000000..1caa743c4018
--- /dev/null
+++ b/tests/entrypoints/openai/test_sleep.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import requests
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "meta-llama/Llama-3.2-1B"
+
+
+def test_sleep_mode():
+    # dtype, max-len etc set so that this can run in CI
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--enable-sleep-mode",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME,
+                            args,
+                            env_dict={
+                                "VLLM_SERVER_DEV_MODE": "1",
+                                "CUDA_VISIBLE_DEVICES": "0"
+                            }) as remote_server:
+        # the /sleep handler reads "level" from the query string,
+        # so pass it as query params rather than form data
+        response = requests.post(remote_server.url_for("/sleep"),
+                                 params={"level": "1"})
+        assert response.status_code == 200
+        response = requests.post(remote_server.url_for("/wake_up"))
+        assert response.status_code == 200
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 053635a28638..93d9b74d8e1e 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1187,6 +1187,12 @@ async def stop_profile(self) -> None:
     async def reset_prefix_cache(self) -> None:
         self.engine.reset_prefix_cache()
 
+    async def sleep(self, level: int = 1) -> None:
+        self.engine.sleep(level)
+
+    async def wake_up(self) -> None:
+        self.engine.wake_up()
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine.add_lora(lora_request)
 
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850ee65a..26dfb63c3dbf 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum):
     RESET_PREFIX_CACHE = 1
 
 
+class RPCSleepRequest(Enum):
+    SLEEP_LEVEL_1 = 1
+    SLEEP_LEVEL_2 = 2
+
+
+class RPCWakeUpRequest(Enum):
+    WAKE_UP = 1
+
+
 @dataclass
 class RPCLoadAdapterRequest:
     lora_request: LoRARequest
@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse:
 
 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
                       RPCUProfileRequest, RPCLoadAdapterRequest,
-                      RPCResetPrefixCacheRequest]
+                      RPCResetPrefixCacheRequest, RPCSleepRequest,
+                      RPCWakeUpRequest]
 
 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                           RPCError]
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e3a4a..c12fe242082b 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -31,8 +31,9 @@
                                          RPCLoadAdapterRequest,
                                          RPCProcessRequest,
                                          RPCResetPrefixCacheRequest,
-                                         RPCStartupRequest, RPCStartupResponse,
-                                         RPCUProfileRequest)
+                                         RPCSleepRequest, RPCStartupRequest,
+                                         RPCStartupResponse,
+                                         RPCUProfileRequest, RPCWakeUpRequest)
 from vllm.engine.protocol import EngineClient
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
@@ -685,6 +686,16 @@ async def reset_prefix_cache(self) -> None:
             request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
             socket=self.input_socket)
 
+    async def sleep(self, level: int = 1) -> None:
+        """Sleep the engine for a given level"""
+        return await self._send_one_way_rpc_request(
+            request=RPCSleepRequest(level), socket=self.input_socket)
+
+    async def wake_up(self) -> None:
+        """Wake up the engine"""
+        return await self._send_one_way_rpc_request(
+            request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         # Uses the same I/O as generate requests
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd79586588..ce24aa21514d 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -20,8 +20,9 @@
                                          RPCLoadAdapterRequest,
                                          RPCProcessRequest,
                                          RPCResetPrefixCacheRequest,
-                                         RPCStartupRequest, RPCStartupResponse,
-                                         RPCUProfileRequest)
+                                         RPCSleepRequest, RPCStartupRequest,
+                                         RPCStartupResponse,
+                                         RPCUProfileRequest, RPCWakeUpRequest)
 # yapf: enable
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -242,6 +243,10 @@ def handle_new_input(self):
                     self._handle_load_adapter_request(request)
                 elif isinstance(request, RPCResetPrefixCacheRequest):
                     self.reset_prefix_cache()
+                elif isinstance(request, RPCSleepRequest):
+                    self.sleep(request.value)
+                elif isinstance(request, RPCWakeUpRequest):
+                    self.wake_up()
                 else:
                     raise ValueError("Unknown RPCRequest Type: "
                                      f"{type(request)}")
@@ -369,6 +374,12 @@ def stop_profile(self) -> None:
     def reset_prefix_cache(self) -> bool:
         return self.engine.reset_prefix_cache()
 
+    def sleep(self, level: int = 1) -> None:
+        self.engine.sleep(level)
+
+    def wake_up(self) -> None:
+        self.engine.wake_up()
+
 
 def signal_handler(*_) -> None:
     raise KeyboardInterrupt("MQLLMEngine terminated")
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index d1112558666f..ee9accd32f21 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -278,6 +278,16 @@ async def reset_prefix_cache(self) -> None:
         """Reset the prefix cache"""
         ...
 
+    @abstractmethod
+    async def sleep(self, level: int = 1) -> None:
+        """Sleep the engine"""
+        ...
+
+    @abstractmethod
+    async def wake_up(self) -> None:
+        """Wake up the engine"""
+        ...
+
     @abstractmethod
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 0de7e2392691..f7162fadbce8 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -625,6 +625,24 @@ async def reset_prefix_cache(raw_request: Request):
         await engine_client(raw_request).reset_prefix_cache()
         return Response(status_code=200)
 
+    @router.post("/sleep")
+    async def sleep(raw_request: Request):
+        # get POST params
+        level = raw_request.query_params.get("level", "1")
+        logger.info("sleep the engine with level %s", level)
+        await engine_client(raw_request).sleep(int(level))
+        # FIXME: in v0 with frontend multiprocessing, the sleep command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
+    @router.post("/wake_up")
+    async def wake_up(raw_request: Request):
+        logger.info("wake up the engine")
+        await engine_client(raw_request).wake_up()
+        # FIXME: in v0 with frontend multiprocessing, the wake-up command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
 
 @router.post("/invocations", dependencies=[Depends(validate_json_request)])
 async def invocations(raw_request: Request):
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index da4930e0e2d8..0bedb5718a4b 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -295,6 +295,7 @@ async def create_transcription(
         # TODO(rob): figure out a way to pipe streaming in.
         # Non-streaming response.
         try:
+            assert result_generator is not None
            async for op in result_generator:
                 result = op
             return TranscriptionResponse(text=result.outputs[0].text)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 1920dbf7a7dc..670454c283da 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -361,6 +361,12 @@ async def stop_profile(self) -> None:
     async def reset_prefix_cache(self) -> None:
         await self.engine_core.reset_prefix_cache_async()
 
+    async def sleep(self, level: int = 1) -> None:
+        await self.engine_core.sleep_async(level)
+
+    async def wake_up(self) -> None:
+        await self.engine_core.wake_up_async()
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         await self.engine_core.add_lora_async(lora_request)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 66e252b7ccb0..03825d6ea430 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -213,6 +213,12 @@ def profile(self, is_start: bool = True):
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 
+    def sleep(self, level: int = 1):
+        self.model_executor.sleep(level)
+
+    def wake_up(self):
+        self.model_executor.wake_up()
+
     def add_lora(self, lora_request: LoRARequest) -> None:
         self.model_executor.add_lora(lora_request)
 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 8641833e438b..cfe554f7150b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -81,6 +81,12 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         raise NotImplementedError
 
+    def sleep(self, level: int = 1) -> None:
+        raise NotImplementedError
+
+    def wake_up(self) -> None:
+        raise NotImplementedError
+
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
@@ -99,6 +105,12 @@ async def profile_async(self, is_start: bool = True) -> None:
     async def reset_prefix_cache_async(self) -> None:
         raise NotImplementedError
 
+    async def sleep_async(self, level: int = 1) -> None:
+        raise NotImplementedError
+
+    async def wake_up_async(self) -> None:
+        raise NotImplementedError
+
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
@@ -138,6 +150,12 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self.engine_core.reset_prefix_cache()
 
+    def sleep(self, level: int = 1) -> None:
+        self.engine_core.sleep(level)
+
+    def wake_up(self) -> None:
+        self.engine_core.wake_up()
+
     def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine_core.add_lora(lora_request)
 
@@ -303,6 +321,12 @@ def reset_prefix_cache(self) -> None:
     def add_lora(self, lora_request: LoRARequest) -> None:
         self._call_utility("add_lora", lora_request)
 
+    def sleep(self, level: int = 1) -> None:
+        self._call_utility("sleep", level)
+
+    def wake_up(self) -> None:
+        self._call_utility("wake_up")
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -380,5 +404,11 @@ async def profile_async(self, is_start: bool = True) -> None:
     async def reset_prefix_cache_async(self) -> None:
         await self._call_utility_async("reset_prefix_cache")
 
+    async def sleep_async(self, level: int = 1) -> None:
+        await self._call_utility_async("sleep", level)
+
+    async def wake_up_async(self) -> None:
+        await self._call_utility_async("wake_up")
+
     async def add_lora_async(self, lora_request: LoRARequest) -> None:
         await self._call_utility_async("add_lora", lora_request)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c9a4c5369dfd..6b7de4deed39 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -169,6 +169,12 @@ def stop_profile(self):
     def reset_prefix_cache(self):
         self.engine_core.reset_prefix_cache()
 
+    def sleep(self, level: int = 1):
+        self.engine_core.sleep(level)
+
+    def wake_up(self):
+        self.engine_core.wake_up()
+
     def get_tokenizer_group(
         self,
         group_type: Type[_G] = BaseTokenizerGroup,
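
A minimal sketch of how a client could exercise the new endpoints once this patch is applied, assuming an OpenAI-compatible server started with --enable-sleep-mode and VLLM_SERVER_DEV_MODE=1 (as in tests/entrypoints/openai/test_sleep.py above); the base URL below is an illustrative assumption, not part of the patch:

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port, for illustration only

# "level" is read from the query string on the server side
# (raw_request.query_params.get("level", "1") in api_server.py).
resp = requests.post(f"{BASE_URL}/sleep", params={"level": "1"})
assert resp.status_code == 200

# Wake the engine back up before issuing further generation requests.
resp = requests.post(f"{BASE_URL}/wake_up")
assert resp.status_code == 200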