[BugFix] Avoid unnecessary Ray import warnings #6079

Merged · 1 commit · Jul 3, 2024
9 changes: 7 additions & 2 deletions vllm/config.py
@@ -655,11 +655,13 @@ def __init__(

             from vllm.executor import ray_utils
             backend = "mp"
-            ray_found = ray_utils.ray is not None
+            ray_found = ray_utils.ray_is_available()
             if cuda_device_count_stateless() < self.world_size:
                 if not ray_found:
-                    raise ValueError("Unable to load Ray which is "
-                                     "required for multi-node inference")
+                    raise ValueError("Unable to load Ray which is "
+                                     "required for multi-node inference, "
+                                     "please install Ray with `pip install "
+                                     "ray`.") from ray_utils.ray_import_err
                 backend = "ray"
             elif ray_found:
                 if self.placement_group:
@@ -691,6 +693,9 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "Unrecognized distributed executor backend. Supported values "
                 "are 'ray' or 'mp'.")
+        if self.distributed_executor_backend == "ray":
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
         if not self.disable_custom_all_reduce and self.world_size > 1:
             if is_hip():
                 self.disable_custom_all_reduce = True
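The key mechanic in this hunk is the `raise ... from ray_utils.ray_import_err` chaining: the import-time warning goes away, but the original `ImportError` is still attached as the `__cause__` of the `ValueError` raised at the point of use. A minimal, self-contained sketch of the same pattern outside vLLM (`not_a_real_module` is a deliberately missing placeholder, not a vLLM module):

```python
# Stash the import failure instead of warning about it immediately.
import_err = None
try:
    import not_a_real_module  # placeholder: this import always fails
except ImportError as e:
    import_err = e

# Later, on the code path that actually needs the dependency:
try:
    raise ValueError("Unable to load Ray which is required "
                     "for multi-node inference") from import_err
except ValueError as err:
    # The stashed ImportError rides along as __cause__, so the traceback
    # still explains *why* the import failed, with no noise at import time.
    assert isinstance(err.__cause__, ImportError)
```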
5 changes: 5 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -378,6 +378,11 @@ def from_engine_args(
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()

if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()

distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)

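With this guard, asking for the Ray engine path fails fast at engine construction when Ray is missing, rather than partway through executor setup. A hedged usage sketch; `AsyncEngineArgs` and its `model`/`engine_use_ray` fields reflect vLLM's public API around the time of this PR, and the model name is only an example:

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",  # example model; any served model works
    engine_use_ray=True,        # opt into the Ray engine path
)
# If Ray is not installed, this now raises a ValueError
# ("Failed to import Ray, please install Ray with `pip install ray`.")
# chained from the original ImportError, with no warning logged earlier.
engine = AsyncLLMEngine.from_engine_args(engine_args)
```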
23 changes: 16 additions & 7 deletions vllm/executor/ray_utils.py
@@ -42,14 +42,26 @@ def execute_model_compiled_dag_remote(self, ignored):
             output = pickle.dumps(output)
             return output
 
+    ray_import_err = None
+
 except ImportError as e:
-    logger.warning(
-        "Failed to import Ray with %r. For multi-node inference, "
-        "please install Ray with `pip install ray`.", e)
     ray = None  # type: ignore
+    ray_import_err = e
     RayWorkerWrapper = None  # type: ignore
 
 
+def ray_is_available() -> bool:
+    """Returns True if Ray is available."""
+    return ray is not None
+
+
+def assert_ray_available():
+    """Raise an exception if Ray is not available."""
+    if ray is None:
+        raise ValueError("Failed to import Ray, please install Ray with "
+                         "`pip install ray`.") from ray_import_err
+
+
 def initialize_ray_cluster(
     parallel_config: ParallelConfig,
     ray_address: Optional[str] = None,
@@ -65,10 +77,7 @@ def initialize_ray_cluster(
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
     """
-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use multi-node "
-            "serving.")
+    assert_ray_available()
 
     # Connect to a ray cluster.
     if is_hip() or is_xpu():
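Taken together, ray_utils.py now follows a common optional-dependency pattern: attempt the import once at module load, record any failure silently, and expose cheap probes for callers to check or enforce availability. A self-contained sketch of that pattern under placeholder names (`some_optional_dep`, `dep_is_available`, and `assert_dep_available` are illustrative stand-ins for Ray and the new helpers):

```python
# Guarded module-level import: no warning is logged if it fails.
import_err = None
try:
    import some_optional_dep as dep  # placeholder for `import ray`
except ImportError as e:
    dep = None  # type: ignore
    import_err = e  # remember why the import failed


def dep_is_available() -> bool:
    """Cheap availability probe, mirroring ray_is_available()."""
    return dep is not None


def assert_dep_available() -> None:
    """Fail loudly only on code paths that truly need the dependency,
    mirroring assert_ray_available()."""
    if dep is None:
        raise ValueError("Failed to import the dependency, please "
                         "install it first.") from import_err


if __name__ == "__main__":
    # Callers that merely prefer the dependency probe quietly; callers
    # that require it assert and get the chained ImportError on failure.
    print(dep_is_available())  # False unless `some_optional_dep` exists
```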