Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions python/ray/train/v2/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@
WORKER_GROUP_START_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S"
DEFAULT_WORKER_GROUP_START_TIMEOUT_S: float = 30.0

# Timeout in seconds for `ray.train.report` to block on synchronization barriers,
# after which a timeout error will be raised.
REPORT_BARRIER_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_TIMEOUT_S"
DEFAULT_REPORT_BARRIER_TIMEOUT_S: float = 60 * 30
# Time in seconds for `ray.train.report` to log a warning if it is waiting for sync
# actor notification of releasing.
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_WARN_INTERVAL_S"
DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S: float = 60
# Timeout in seconds for collective operations, after which a timeout error
# is raised. A negative value disables the timeout.
COLLECTIVE_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_COLLECTIVE_TIMEOUT_S"
# NOTE: Default to no timeout to avoid introducing more timeouts for users to configure.
# For example, users can already configure timeouts in torch distributed.
DEFAULT_COLLECTIVE_TIMEOUT_S: float = -1
# Interval in seconds to log a warning when waiting for a collective operation to complete.
COLLECTIVE_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_COLLECTIVE_WARN_INTERVAL_S"
DEFAULT_COLLECTIVE_WARN_INTERVAL_S: float = 60

# Environment variable to enable the print function patching.
ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH"
Expand Down Expand Up @@ -98,8 +98,8 @@
HEALTH_CHECK_INTERVAL_S_ENV_VAR,
WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
COLLECTIVE_TIMEOUT_S_ENV_VAR,
COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
ENABLE_PRINT_PATCH_ENV_VAR,
ENABLE_CONTROLLER_STRUCTURED_LOGGING_ENV_VAR,
ENABLE_WORKER_STRUCTURED_LOGGING_ENV_VAR,
Expand Down
12 changes: 6 additions & 6 deletions python/ray/train/v2/_internal/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import List, Optional

from ray.train.v2._internal.constants import (
COLLECTIVE_TIMEOUT_S_ENV_VAR,
DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
)
Expand Down Expand Up @@ -131,11 +131,11 @@ def __init__(
self._timeout_s = timeout_s

message = (
f"The broadcast operation timed out after {time_elapsed:.2f} seconds. "
"Please make sure all worker ranks call `ray.train.report`. \n"
f"The following ranks have not called it: {missing_ranks}\n"
f"You can set this timeout with the {REPORT_BARRIER_TIMEOUT_S_ENV_VAR} "
f"environment variable (current value: {timeout_s:.2f} s)."
f"The collective operation timed out after {time_elapsed:.2f} seconds. "
f"The following ranks have not joined the collective operation: {missing_ranks}\n"
f"You can set the timeout with the {COLLECTIVE_TIMEOUT_S_ENV_VAR} "
f"environment variable (current value: {timeout_s:.2f} seconds). "
"Disable the timeout by setting the environment variable to -1."
)
super().__init__(message)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import ray
from ray.train.v2._internal.constants import (
DEFAULT_REPORT_BARRIER_TIMEOUT_S,
DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
DEFAULT_COLLECTIVE_TIMEOUT_S,
DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
)
from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError

Expand Down Expand Up @@ -35,8 +35,8 @@ class SynchronizationActor:

def __init__(
self,
timeout_s: float = DEFAULT_REPORT_BARRIER_TIMEOUT_S,
warn_interval_s: float = DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
timeout_s: float = DEFAULT_COLLECTIVE_TIMEOUT_S,
warn_interval_s: float = DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
):
self._counter: int = 0
self._world_size: int = 0
Expand Down Expand Up @@ -139,7 +139,7 @@ async def _wait_with_logging(
world_size=self._world_size,
max_time_elapsed_s=self._get_time_elapsed(),
missing_ranks=self._get_missing_ranks(),
warn_interval_env_var=REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
warn_interval_env_var=COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
warn_interval_s=self._warn_interval_s,
),
)
Expand Down Expand Up @@ -189,7 +189,7 @@ async def broadcast_from_rank_zero(
self._wait_with_logging(
self._condition, world_rank, caller_method_name
),
timeout=self._timeout_s,
timeout=self._timeout_s if self._timeout_s >= 0 else None,
)
return self._reduced_data
except (asyncio.TimeoutError, TimeoutError) as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from ray.runtime_env import RuntimeEnv
from ray.train._internal.base_worker_group import BaseWorkerGroup
from ray.train.v2._internal.constants import (
DEFAULT_REPORT_BARRIER_TIMEOUT_S,
DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
COLLECTIVE_TIMEOUT_S_ENV_VAR,
COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
DEFAULT_COLLECTIVE_TIMEOUT_S,
DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
get_env_vars_to_propagate,
Expand Down Expand Up @@ -177,12 +177,12 @@ def __init__(
DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
)
)
self._report_barrier_timeout_s = env_float(
REPORT_BARRIER_TIMEOUT_S_ENV_VAR, DEFAULT_REPORT_BARRIER_TIMEOUT_S
self._collective_timeout_s = env_float(
COLLECTIVE_TIMEOUT_S_ENV_VAR, DEFAULT_COLLECTIVE_TIMEOUT_S
)
self._report_barrier_warn_interval_s = env_float(
REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
self._collective_warn_interval_s = env_float(
COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
)

################################################################################
Expand Down Expand Up @@ -309,8 +309,8 @@ def _start_impl(
soft=False,
)
).remote(
timeout_s=self._report_barrier_timeout_s,
warn_interval_s=self._report_barrier_warn_interval_s,
timeout_s=self._collective_timeout_s,
warn_interval_s=self._collective_warn_interval_s,
)
worker_group_state_builder.with_sync_actor(sync_actor)

Expand Down
47 changes: 41 additions & 6 deletions python/ray/train/v2/tests/test_sync_actor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import ray
from ray.train.v2._internal.constants import DEFAULT_COLLECTIVE_TIMEOUT_S
from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor

Expand All @@ -14,10 +15,10 @@ def ray_start_4_cpus():

@pytest.mark.parametrize("world_size", [1, 10, 1000])
def test_broadcast_from_rank_0(world_size):
"""The test checks if all workers can reach a consensus on a data.
"""Check that rank 0 can broadcast data to all other workers.
Every worker sends data with a string "data-{rank}" that is unique
to the worker. Expected to get a consensus data of "data-0".
Also checks if the counter is reset to 0 after all workers have data.
to the worker. Everyone should receive the data from rank 0, which is "data-0".
Also assert that the actor state is reset after the broadcast function returns.
"""
sync_actor = SynchronizationActor.remote()
# Test broadcast_from_rank_zero with the parametrized world size
Expand All @@ -39,7 +40,7 @@ def test_broadcast_from_rank_0(world_size):
assert ray.get(sync_actor.get_reduced_data.remote()) is None


def test_hang():
def test_hang_with_timeout():
"""The test checks if the workers are blocked and hang when the world size
is greater than the number of workers. The workers should block and hang
until the barrier is lifted.
Expand All @@ -61,13 +62,47 @@ def test_hang():
# after 1 second
with pytest.raises(BroadcastCollectiveTimeoutError) as excinfo:
ray.get(remote_tasks)
assert "The following ranks have not called it: [9]" in str(excinfo.value)
assert "The following ranks have not joined the collective operation: [9]" in str(
excinfo.value
)


def test_hang_without_timeout():
    """Check the default no-timeout behavior of the synchronization actor.

    Nine of the ten expected ranks join the broadcast first; none of their
    calls may finish within a short wait window. The tenth rank then joins
    and all calls are awaited.
    """
    assert DEFAULT_COLLECTIVE_TIMEOUT_S == -1

    sync_actor = SynchronizationActor.remote()

    def join_broadcast(rank):
        # Submit one rank's participation in the collective.
        return sync_actor.broadcast_from_rank_zero.remote(
            world_rank=rank,
            world_size=10,
            data=f"data-{rank}",
            caller_method_name="broadcast_from_rank_zero",
        )

    # Ranks 0..8 join; rank 9 is deliberately held back.
    pending = [join_broadcast(rank) for rank in range(9)]

    # Only wait briefly: the point is that nothing completes or errors out.
    ready, _ = ray.wait(pending, num_returns=len(pending), timeout=2)
    assert not ready, "All tasks should be hanging, but some are done."

    # The last rank joins, so the outstanding calls can now be collected.
    pending.append(join_broadcast(9))
    ray.get(pending)


def test_world_size_mismatch():
"""The test checks if the workers are blocked and raise a value error
when the world size is different. The workers should block and raise
an ValueError.
a ValueError.
"""
sync_actor = SynchronizationActor.remote()
remote_tasks = []
Expand Down