Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime env] Fix Ray hangs when nonexistent conda environment is specified #28105 #34956

Merged
merged 13 commits into from
Aug 23, 2023
48 changes: 36 additions & 12 deletions python/ray/_private/runtime_env/conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
create_conda_env_if_needed,
delete_conda_env,
get_conda_activate_commands,
get_conda_env_list,
)
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.validation import parse_and_validate_conda
from ray._private.utils import (
get_directory_size_bytes,
get_master_wheel_url,
Expand Down Expand Up @@ -218,11 +220,11 @@ def get_uri(runtime_env: Dict) -> Optional[str]:
conda = runtime_env.get("conda")
if conda is not None:
if isinstance(conda, str):
# User-preinstalled conda env. We don't garbage collect these, so
# we don't track them with URIs.
uri = None
elif isinstance(conda, dict):
uri = "conda://" + _get_conda_env_hash(conda_dict=conda)
# User-preinstalled conda env. We don't garbage collect these, so
# we don't track them with URIs.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
uri = f"conda://{_get_conda_env_hash(conda_dict=conda)}"
else:
raise TypeError(
"conda field received by RuntimeEnvAgent must be "
Expand Down Expand Up @@ -269,6 +271,12 @@ def __init__(self, resources_dir: str):
self._installs_and_deletions_file_lock = os.path.join(
self._resources_dir, "ray-conda-installs-and-deletions.lock"
)
# A set of named conda environments (instead of yaml or dict)
# that are validated to exist.
# NOTE: It has to be only used within the same thread, which
# is an event loop.
# Also, we don't need to GC this field because it is pretty small.
self._validated_named_conda_env = set()

def _get_path_from_hash(self, hash: str) -> str:
"""Generate a path from the hash of a conda or pip spec.
Expand Down Expand Up @@ -319,18 +327,34 @@ async def create(
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
) -> int:
if uri is None:
# The "conda" field is the name of an existing conda env, so no
# need to create one.
# TODO(architkulkarni): Try "conda activate" here to see if the
# env exists, and raise an exception if it doesn't.
if not runtime_env.has_conda():
return 0

# Currently create method is still a sync process, to avoid blocking
# the loop, need to run this function in another thread.
# TODO(Catch-Bull): Refactor method create into an async process, and
# make this method running in current loop.
def _create():
result = parse_and_validate_conda(runtime_env.get("conda"))
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(result, str):
# The conda env name is given.
# In this case, we only verify if the given
# conda env exists.

# If the env is already validated, do nothing.
if result in self._validated_named_conda_env:
return 0

conda_env_list = get_conda_env_list()
envs = [Path(env).name for env in conda_env_list]
if result not in envs:
raise ValueError(
f"The given conda environment '{result}' "
f"from the runtime env {runtime_env} doesn't "
"exist from the output of `conda env list --json`. "
"You can only specify an env that already exists. "
f"Please make sure to create an env {result} "
)
self._validated_named_conda_env.add(result)
return 0

logger.debug(
"Setting up conda for runtime_env: " f"{runtime_env.serialize()}"
)
Expand Down
1 change: 1 addition & 0 deletions python/ray/_private/runtime_env/conda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def create_conda_env_if_needed(
"""
if logger is None:
logger = logging.getLogger(__name__)

conda_path = get_conda_bin_executable("conda")
try:
exec_cmd([conda_path, "--help"], throw_on_error=False)
Expand Down
32 changes: 32 additions & 0 deletions python/ray/tests/test_runtime_env_complicated.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,38 @@ def wrapped_version(self):
assert ray.get(actor.wrapped_version.remote()) == package_version


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
)
def test_task_conda_env_validation_cached(shutdown_only):
"""Verify that when a task is running with the same conda env
it doesn't validate if env exists.
"""
# The first run would be slower because we need to validate
# if the package exists.
ray.init()
# version = EMOJI_VERSIONS[0]
# runtime_env = {"conda": f"package-{version}"}
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
runtime_env = {"conda": "core"}
task = get_emoji_version.options(runtime_env=runtime_env)
s = time.time()
ray.get(task.remote())
first_run = time.time() - s
# Typically takes 1~2 seconds.
print("First run took", first_run)

# We should verify this doesn't happen
# from the second run.
s = time.time()
for _ in range(10):
ray.get(task.remote())
second_10_runs = time.time() - s
# Typicall takes less than 100ms.
print("second 10 runs took", second_10_runs)
assert second_10_runs < first_run


@pytest.mark.skipif(
os.environ.get("CONDA_DEFAULT_ENV") is None,
reason="must be run from within a conda environment",
Expand Down
18 changes: 18 additions & 0 deletions python/ray/tests/test_runtime_env_conda_and_pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,24 @@ def f():
assert ray.get(f.remote()) == 0


def test_runtime_env_conda_not_exists_not_hang(shutdown_only):
"""Verify when the conda env doesn't exist, it doesn't hang Ray."""
ray.init(runtime_env={"conda": "env_which_does_not_exist"})

@ray.remote
def f():
return 1

refs = [f.remote() for _ in range(5)]

for ref in refs:
with pytest.raises(ray.exceptions.RuntimeEnvSetupError) as exc_info:
ray.get(ref)
assert "doesn't exist from the output of `conda env list --json`" in str(
exc_info.value
) # noqa


def test_get_requirements_file():
"""Unit test for _PathHelper.get_requirements_file."""
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down
Loading