diff --git a/modal/runner/engines/vllm.py b/modal/runner/engines/vllm.py index 9a7110f..6d6f32d 100644 --- a/modal/runner/engines/vllm.py +++ b/modal/runner/engines/vllm.py @@ -1,6 +1,6 @@ from typing import Optional -from modal import method +from modal import enter, method from pydantic import BaseModel from shared.protocol import ( @@ -44,12 +44,17 @@ def __init__(self, params: VllmParams): from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine - engine_args = AsyncEngineArgs( + self.engine_args = AsyncEngineArgs( **params.dict(), disable_log_requests=True, ) + self.engine: AsyncLLMEngine | None = None - self.engine = AsyncLLMEngine.from_engine_args(engine_args) + @enter() + def start(self): + from vllm.engine.async_llm_engine import AsyncLLMEngine + + self.engine = AsyncLLMEngine.from_engine_args(self.engine_args) # @method() # async def tokenize_prompt(self, payload: Payload) -> List[int]: @@ -62,6 +67,8 @@ def __init__(self, params: VllmParams): @method() async def generate(self, payload: CompletionPayload, params): + assert self.engine is not None, "Engine not initialized" + try: import time