Skip to content

Commit

Permalink
perf: move large vLLM imports into the image.imports() context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
sambarnes committed Jan 20, 2024
1 parent 5b61024 commit 6b01af7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
13 changes: 2 additions & 11 deletions modal/runner/containers/vllm_unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@

import modal.gpu
import sentry_sdk
from modal import Image

from runner.engines.vllm import VllmEngine, VllmParams
from runner.engines.vllm import VllmEngine, VllmParams, vllm_image
from runner.shared.common import stub
from shared.logging import (
add_observability,
get_logger,
get_observability_secrets,
)
from shared.volumes import does_model_exist, models_path, models_volume

_vllm_image = add_observability(
Image.from_registry(
"nvidia/cuda:12.1.0-base-ubuntu22.04",
add_python="3.10",
).pip_install("vllm==0.2.6", "sentry-sdk==1.39.1")
)


def _make_container(
name: str, num_gpus: int = 1, memory: int = 0, concurrent_inputs: int = 8
Expand Down Expand Up @@ -74,7 +65,7 @@ def __init__(

wrap = stub.cls(
volumes={models_path: models_volume},
image=_vllm_image,
image=vllm_image,
# Default CPU memory is 128 on modal. Request more memory for larger
# windows of vLLM's batch loading weights into GPU memory.
memory=1024,
Expand Down
24 changes: 18 additions & 6 deletions modal/runner/engines/vllm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Optional

from modal import method
from modal import Image, method
from pydantic import BaseModel

from shared.logging import get_logger, timer
from shared.logging import (
add_observability,
get_logger,
timer,
)
from shared.protocol import (
CompletionPayload,
create_error_text,
Expand All @@ -16,6 +20,18 @@
logger = get_logger(__name__)


vllm_image = add_observability(
Image.from_registry(
"nvidia/cuda:12.1.0-base-ubuntu22.04",
add_python="3.10",
).pip_install("vllm==0.2.6", "sentry-sdk==1.39.1")
)

with vllm_image.imports():
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py#L192
class VllmParams(BaseModel):
model: str
Expand Down Expand Up @@ -44,10 +60,6 @@ class VllmParams(BaseModel):

class VllmEngine(BaseEngine):
def __init__(self, params: VllmParams):
with timer("imports"):
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

self.engine_args = AsyncEngineArgs(
**params.dict(),
disable_log_requests=True,
Expand Down

0 comments on commit 6b01af7

Please sign in to comment.