
Commit d6e349d

Update
[ghstack-poisoned]

Merge commit with 2 parents: e4ce1d2 + 58f7ac7

File tree: 2 files changed (+54, -20 lines)


.github/unittest/llm/scripts_llm/setup_env.sh

Lines changed: 9 additions & 3 deletions
@@ -6,10 +6,13 @@
 # Do not install PyTorch and torchvision here, otherwise they also get cached.
 
 set -e
-apt-get update && apt-get upgrade -y && apt-get install -y git cmake
+export DEBIAN_FRONTEND=noninteractive
+export TZ=UTC
+apt-get update
+apt-get install -yq --no-install-recommends git cmake
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
-apt-get install -y wget \
+apt-get install -yq --no-install-recommends wget \
     gcc \
     g++ \
     unzip \
@@ -27,7 +30,10 @@ apt-get install -y wget \
     libgles2
 
 # Upgrade specific package
-apt-get upgrade -y libstdc++6
+apt-get install -yq --no-install-recommends --only-upgrade libstdc++6
+
+apt-get clean
+rm -rf /var/lib/apt/lists/*
 
 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 root_dir="$(git rev-parse --show-toplevel)"

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 45 additions & 17 deletions
@@ -20,12 +20,9 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING
 
-import ray
 
 import torch
 
-from ray.util.placement_group import placement_group, remove_placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torchrl._utils import logger as torchrl_logger
 
 # Import RLvLLMEngine and shared utilities
@@ -43,6 +40,24 @@
 TIMEOUT_SECONDS = os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", 300)
 
 
+def _get_ray():
+    """Import Ray on demand to avoid global import side-effects.
+
+    Returns:
+        ModuleType: The imported Ray module.
+
+    Raises:
+        ImportError: If Ray is not installed.
+    """
+    try:
+        import ray  # type: ignore
+
+        return ray
+    except Exception as e:  # pragma: no cover - surfaced to callers
+        raise ImportError(
+            "ray is not installed. Please install it with `pip install ray`."
+        ) from e
+
 class _AsyncvLLMWorker:
     """Async vLLM worker for Ray with weight update capabilities.
 
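For readers skimming the diff, here is a minimal, self-contained sketch of the lazy-import pattern that `_get_ray()` implements. The `submit_square` caller and the `_square` task are illustrative only and not part of TorchRL; the sketch assumes Ray is installed at call time.

from types import ModuleType


def _get_ray() -> ModuleType:
    """Import Ray only when first needed so this module imports without Ray installed."""
    try:
        import ray
    except ImportError as e:
        raise ImportError(
            "ray is not installed. Please install it with `pip install ray`."
        ) from e
    return ray


def _square(v: int) -> int:
    return v * v


def submit_square(x: int) -> int:
    """Illustrative caller: Ray is imported (and validated) on first use only."""
    ray = _get_ray()
    if not ray.is_initialized():
        ray.init()
    square_task = ray.remote(_square)  # wrap a plain function as a Ray task
    return ray.get(square_task.remote(x))

The point of the pattern is that `import torchrl...vllm_async` no longer fails or pays Ray's import cost when Ray is absent; the ImportError only surfaces when Ray functionality is actually requested.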
@@ -267,7 +282,7 @@ async def generate(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
 
-        from vllm import RequestOutput, SamplingParams, TokensPrompt
+        from vllm import SamplingParams, TokensPrompt
 
         # Track whether input was originally a single prompt
         single_prompt_input = False
@@ -474,11 +489,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
     )
 
 
-# Create Ray remote versions
-if ray is not None and _has_vllm:
-    _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
-else:
-    _AsyncLLMEngineActor = None
+# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
 
 
 class AsyncVLLM(RLvLLMEngine):
@@ -583,17 +594,18 @@ def __init__(
             raise ImportError(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
-        if ray is None:
-            raise ImportError(
-                "ray is not installed. Please install it with `pip install ray`."
-            )
+        # Lazily import ray only when constructing the actor class to avoid global import
 
         # Enable prefix caching by default for better performance
        engine_args.enable_prefix_caching = enable_prefix_caching
 
         self.engine_args = engine_args
         self.num_replicas = num_replicas
-        self.actor_class = actor_class or _AsyncLLMEngineActor
+        if actor_class is None:
+            ray = _get_ray()
+            self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
+        else:
+            self.actor_class = actor_class
         self.actors: list = []
         self._launched = False
         self._service_id = uuid.uuid4().hex[
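A hedged sketch of what the deferred `ray.remote(num_cpus=0, num_gpus=0)(...)` wrapping in `__init__` does, using a stand-in `EchoEngine` class in place of `_AsyncLLMEngine` (names are illustrative; assumes Ray is installed and `_get_ray` is defined as sketched above):

class EchoEngine:
    """Stand-in for _AsyncLLMEngine; any plain class can be wrapped this way."""

    def __init__(self, replica_id: int) -> None:
        self.replica_id = replica_id

    def ping(self) -> str:
        return f"replica {self.replica_id} is alive"


def make_engine_actor(replica_id: int):
    ray = _get_ray()  # Ray is only required once an actor is actually requested
    if not ray.is_initialized():
        ray.init()
    # Same call shape as the diff: the wrapper itself reserves no CPUs/GPUs;
    # real resources are attached later via .options(...) at launch time.
    actor_cls = ray.remote(num_cpus=0, num_gpus=0)(EchoEngine)
    actor = actor_cls.remote(replica_id)
    print(ray.get(actor.ping.remote()))
    return actor

Wrapping at construction time rather than at module import time means the module-level `_AsyncLLMEngineActor` global (and its `if ray is not None` guard) is no longer needed.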
@@ -608,6 +620,11 @@ def _launch(self):
             torchrl_logger.warning("AsyncVLLMEngineService already launched")
             return
 
+        # Local imports to avoid global Ray dependency
+        ray = _get_ray()
+        from ray.util.placement_group import placement_group
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
         torchrl_logger.info(
             f"Launching {self.num_replicas} async vLLM engine actors..."
         )
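The local imports added to `_launch()` are the ones used to co-locate replica actors on reserved resources. A rough sketch of that placement-group flow, reusing the `EchoEngine` stand-in from the previous sketch; bundle shapes and parameter names here are illustrative, not the exact TorchRL values:

def launch_with_placement_group(num_replicas: int = 2, gpus_per_replica: int = 1):
    ray = _get_ray()
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    if not ray.is_initialized():
        ray.init()

    # Reserve all bundles atomically so replicas land on compatible resources.
    pg = placement_group(
        [{"CPU": 1, "GPU": gpus_per_replica} for _ in range(num_replicas)],
        strategy="PACK",
    )
    ray.get(pg.ready())  # block until the reservation succeeds

    actor_cls = ray.remote(num_cpus=0, num_gpus=0)(EchoEngine)
    actors = [
        actor_cls.options(
            num_gpus=gpus_per_replica,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_bundle_index=i
            ),
        ).remote(replica_id=i)
        for i in range(num_replicas)
    ]
    return actors, pg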
@@ -938,6 +955,7 @@ def generate(
         Returns:
             RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
         """
+        ray = _get_ray()
         # Check if this is a batch request
         if self._is_batch(prompts, prompt_token_ids):
             # Handle batched input by unbinding and sending individual requests
@@ -1062,6 +1080,9 @@ def shutdown(self):
             f"Shutting down {len(self.actors)} async vLLM engine actors..."
         )
 
+        ray = _get_ray()
+        from ray.util.placement_group import remove_placement_group
+
         # Kill all actors
         for i, actor in enumerate(self.actors):
             try:
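And the teardown counterpart that `shutdown()` now imports locally, sketched under the same illustrative setup: kill the actors first, then release the reservation.

def shutdown_actors(actors, pg=None):
    ray = _get_ray()
    from ray.util.placement_group import remove_placement_group

    for actor in actors:
        try:
            ray.kill(actor)
        except Exception as exc:  # an actor may already be gone
            print(f"Ignoring error while killing actor: {exc}")
    if pg is not None:
        remove_placement_group(pg)  # free the reserved bundles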
@@ -1254,6 +1275,7 @@ def _update_weights_with_nccl_broadcast_simple(
         )
 
         updated_weights = 0
+        ray = _get_ray()
         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for name, weight in gpu_weights.items():
                 # Convert dtype to string name (like periodic-mono)
@@ -1330,6 +1352,7 @@ def get_num_unfinished_requests(
                 "AsyncVLLM service must be launched before getting request counts"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1360,6 +1383,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]
                 "AsyncVLLM service must be launched before getting cache usage"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1672,6 +1696,7 @@ def _select_by_requests(self) -> int:
             futures = [
                 actor.get_num_unfinished_requests.remote() for actor in self.actors
             ]
+            ray = _get_ray()
             request_counts = ray.get(futures)
 
             # Find the actor with minimum pending requests
@@ -1699,6 +1724,7 @@ def _select_by_cache_usage(self) -> int:
         else:
             # Query actors directly
             futures = [actor.get_cache_usage.remote() for actor in self.actors]
+            ray = _get_ray()
             cache_usages = ray.get(futures)
 
             # Find the actor with minimum cache usage
@@ -1838,7 +1864,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool:
         futures = [
             actor.get_num_unfinished_requests.remote() for actor in self.actors
         ]
-        request_counts = ray.get(futures)
+        ray = _get_ray()
+        request_counts = ray.get(futures)
 
         if not request_counts:
             return False
@@ -1887,8 +1914,9 @@ def get_stats(self) -> dict[str, Any]:
             cache_futures = [
                 actor.get_cache_usage.remote() for actor in self.actors
             ]
-            request_counts = ray.get(request_futures)
-            cache_usages = ray.get(cache_futures)
+            ray = _get_ray()
+            request_counts = ray.get(request_futures)
+            cache_usages = ray.get(cache_futures)
 
             for i, (requests, cache_usage) in enumerate(
                 zip(request_counts, cache_usages)
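Finally, the repeated `ray = _get_ray()` followed by `ray.get(futures)` in the load-balancing and stats methods is the usual fan-out/gather pattern. A compact sketch; the two method names are taken from the diff, the surrounding function is illustrative:

def collect_stats(actors) -> dict[str, list]:
    ray = _get_ray()
    # Fan out: one non-blocking remote call per replica.
    request_futures = [a.get_num_unfinished_requests.remote() for a in actors]
    cache_futures = [a.get_cache_usage.remote() for a in actors]
    # Gather: a single ray.get resolves each list of futures, preserving order.
    return {
        "unfinished_requests": ray.get(request_futures),
        "cache_usage": ray.get(cache_futures),
    }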

0 commit comments