 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING
 
-import ray
 
 import torch
 
-from ray.util.placement_group import placement_group, remove_placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torchrl._utils import logger as torchrl_logger
 
 # Import RLvLLMEngine and shared utilities
@@ -43,6 +40,24 @@
 TIMEOUT_SECONDS = os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", 300)
 
 
+def _get_ray():
+    """Import Ray on demand to avoid global import side-effects.
+
+    Returns:
+        ModuleType: The imported Ray module.
+
+    Raises:
+        ImportError: If Ray is not installed.
+    """
+    try:
+        import ray  # type: ignore
+
+        return ray
+    except Exception as e:  # pragma: no cover - surfaced to callers
+        raise ImportError(
+            "ray is not installed. Please install it with `pip install ray`."
+        ) from e
+
 class _AsyncvLLMWorker:
     """Async vLLM worker for Ray with weight update capabilities.
 
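A minimal sketch of the call-site pattern this helper enables (the `fan_out` function is illustrative, not part of the file): importing the module no longer requires Ray, and the ImportError surfaces only when a Ray-backed code path actually runs.

def fan_out(actors, method_name):
    # Deferred import: raises ImportError here, at call time, not at module import time.
    ray = _get_ray()
    futures = [getattr(actor, method_name).remote() for actor in actors]
    return ray.get(futures)  # block until every actor has answered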
@@ -267,7 +282,7 @@ async def generate(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
 
-        from vllm import RequestOutput, SamplingParams, TokensPrompt
+        from vllm import SamplingParams, TokensPrompt
 
         # Track whether input was originally a single prompt
         single_prompt_input = False
@@ -474,11 +489,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
     )
 
 
-# Create Ray remote versions
-if ray is not None and _has_vllm:
-    _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
-else:
-    _AsyncLLMEngineActor = None
+# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
 
 
 class AsyncVLLM(RLvLLMEngine):
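The removed block above wrapped `_AsyncLLMEngine` at import time; `ray.remote(num_cpus=0, num_gpus=0)(cls)` simply turns a plain class into a Ray actor class, so the same call can run later inside `__init__`. A minimal sketch of that wrapping pattern, assuming Ray is installed (the `Counter` class is illustrative):

import ray


class Counter:
    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1
        return self.n


ray.init()
# Same wrapping call the diff performs on _AsyncLLMEngine; num_cpus=0/num_gpus=0
# means the actor shell itself reserves no scheduler resources.
CounterActor = ray.remote(num_cpus=0, num_gpus=0)(Counter)
handle = CounterActor.remote()          # spawns the actor process
print(ray.get(handle.bump.remote()))    # -> 1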
@@ -583,17 +594,18 @@ def __init__(
             raise ImportError(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
-        if ray is None:
-            raise ImportError(
-                "ray is not installed. Please install it with `pip install ray`."
-            )
+        # Lazily import ray only when constructing the actor class to avoid global import
 
         # Enable prefix caching by default for better performance
         engine_args.enable_prefix_caching = enable_prefix_caching
 
         self.engine_args = engine_args
         self.num_replicas = num_replicas
-        self.actor_class = actor_class or _AsyncLLMEngineActor
+        if actor_class is None:
+            ray = _get_ray()
+            self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
+        else:
+            self.actor_class = actor_class
         self.actors: list = []
         self._launched = False
         self._service_id = uuid.uuid4().hex[
@@ -608,6 +620,11 @@ def _launch(self):
             torchrl_logger.warning("AsyncVLLMEngineService already launched")
             return
 
+        # Local imports to avoid global Ray dependency
+        ray = _get_ray()
+        from ray.util.placement_group import placement_group
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
         torchrl_logger.info(
             f"Launching {self.num_replicas} async vLLM engine actors..."
         )
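For reference, a self-contained sketch of the placement-group scheduling pattern these local imports support (bundle contents and the `Worker` class are illustrative, not taken from this file): reserve resource bundles, wait for them, then pin actors to the group.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
# Reserve one bundle with a GPU and a CPU; "PACK" keeps bundles on as few nodes as possible.
# Running this requires a GPU to be available in the cluster.
pg = placement_group([{"GPU": 1, "CPU": 1}], strategy="PACK")
ray.get(pg.ready())  # block until the resources are actually reserved


@ray.remote(num_gpus=1)
class Worker:
    def ping(self):
        return "ok"


worker = Worker.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
).remote()
print(ray.get(worker.ping.remote()))  # -> "ok"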
@@ -938,6 +955,7 @@ def generate(
         Returns:
             RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
         """
+        ray = _get_ray()
         # Check if this is a batch request
         if self._is_batch(prompts, prompt_token_ids):
             # Handle batched input by unbinding and sending individual requests
@@ -1062,6 +1080,9 @@ def shutdown(self):
             f"Shutting down {len(self.actors)} async vLLM engine actors..."
         )
 
+        ray = _get_ray()
+        from ray.util.placement_group import remove_placement_group
+
         # Kill all actors
         for i, actor in enumerate(self.actors):
             try:
@@ -1254,6 +1275,7 @@ def _update_weights_with_nccl_broadcast_simple(
         )
 
         updated_weights = 0
+        ray = _get_ray()
         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for name, weight in gpu_weights.items():
                 # Convert dtype to string name (like periodic-mono)
@@ -1330,6 +1352,7 @@ def get_num_unfinished_requests(
                 "AsyncVLLM service must be launched before getting request counts"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1360,6 +1383,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]
                 "AsyncVLLM service must be launched before getting cache usage"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1672,6 +1696,7 @@ def _select_by_requests(self) -> int:
             futures = [
                 actor.get_num_unfinished_requests.remote() for actor in self.actors
             ]
+            ray = _get_ray()
             request_counts = ray.get(futures)
 
         # Find the actor with minimum pending requests
@@ -1699,6 +1724,7 @@ def _select_by_cache_usage(self) -> int:
         else:
             # Query actors directly
             futures = [actor.get_cache_usage.remote() for actor in self.actors]
+            ray = _get_ray()
             cache_usages = ray.get(futures)
 
         # Find the actor with minimum cache usage
@@ -1838,7 +1864,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool:
         futures = [
             actor.get_num_unfinished_requests.remote() for actor in self.actors
         ]
-        request_counts = ray.get(futures)
+        ray = _get_ray()
+        request_counts = ray.get(futures)
 
         if not request_counts:
             return False
@@ -1887,8 +1914,9 @@ def get_stats(self) -> dict[str, Any]:
             cache_futures = [
                 actor.get_cache_usage.remote() for actor in self.actors
             ]
-            request_counts = ray.get(request_futures)
-            cache_usages = ray.get(cache_futures)
+            ray = _get_ray()
+            request_counts = ray.get(request_futures)
+            cache_usages = ray.get(cache_futures)
 
             for i, (requests, cache_usage) in enumerate(
                 zip(request_counts, cache_usages)
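As in the selection helpers earlier, the code above submits every `.remote()` call before any `ray.get`, so the per-actor queries overlap rather than run one at a time; a condensed sketch of the pattern, assuming `actors` is a list of Ray actor handles:

ray = _get_ray()
request_futures = [a.get_num_unfinished_requests.remote() for a in actors]
cache_futures = [a.get_cache_usage.remote() for a in actors]
# Submitting first and gathering second overlaps all actor round-trips.
request_counts = ray.get(request_futures)
cache_usages = ray.get(cache_futures)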