diff --git a/python/ray/train/v2/_internal/constants.py b/python/ray/train/v2/_internal/constants.py
index 08eba11bca8c..c5589ef2f591 100644
--- a/python/ray/train/v2/_internal/constants.py
+++ b/python/ray/train/v2/_internal/constants.py
@@ -43,14 +43,14 @@
 WORKER_GROUP_START_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S"
 DEFAULT_WORKER_GROUP_START_TIMEOUT_S: float = 30.0
 
-# Timeout in seconds for `ray.train.report` to block on synchronization barriers,
-# after which a timeout error will be raised.
-REPORT_BARRIER_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_TIMEOUT_S"
-DEFAULT_REPORT_BARRIER_TIMEOUT_S: float = 60 * 30
-# Time in seconds for `ray.train.report` to log a warning if it is waiting for sync
-# actor notification of releasing.
-REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_REPORT_BARRIER_WARN_INTERVAL_S"
-DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S: float = 60
+# Time in seconds for collective operations before raising a timeout error.
+COLLECTIVE_TIMEOUT_S_ENV_VAR = "RAY_TRAIN_COLLECTIVE_TIMEOUT_S"
+# NOTE: Default to no timeout to avoid introducing more timeouts for users to configure.
+# For example, users can already configure timeouts in torch distributed.
+DEFAULT_COLLECTIVE_TIMEOUT_S: float = -1
+# Interval in seconds to log a warning when waiting for a collective operation to complete.
+COLLECTIVE_WARN_INTERVAL_S_ENV_VAR = "RAY_TRAIN_COLLECTIVE_WARN_INTERVAL_S"
+DEFAULT_COLLECTIVE_WARN_INTERVAL_S: float = 60
 
 # Environment variable to enable the print function patching.
 ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH"
@@ -98,8 +98,8 @@
     HEALTH_CHECK_INTERVAL_S_ENV_VAR,
     WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
     WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
-    REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
-    REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
+    COLLECTIVE_TIMEOUT_S_ENV_VAR,
+    COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
     ENABLE_PRINT_PATCH_ENV_VAR,
     ENABLE_CONTROLLER_STRUCTURED_LOGGING_ENV_VAR,
     ENABLE_WORKER_STRUCTURED_LOGGING_ENV_VAR,
diff --git a/python/ray/train/v2/_internal/exceptions.py b/python/ray/train/v2/_internal/exceptions.py
index 4d5d62718b5d..592836f4f58e 100644
--- a/python/ray/train/v2/_internal/exceptions.py
+++ b/python/ray/train/v2/_internal/exceptions.py
@@ -2,9 +2,9 @@
 from typing import List, Optional
 
 from ray.train.v2._internal.constants import (
+    COLLECTIVE_TIMEOUT_S_ENV_VAR,
     DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
     DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
-    REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
     WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
     WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
 )
@@ -131,11 +131,11 @@ def __init__(
         self._timeout_s = timeout_s
 
         message = (
-            f"The broadcast operation timed out after {time_elapsed:.2f} seconds. "
-            "Please make sure all worker ranks call `ray.train.report`. \n"
-            f"The following ranks have not called it: {missing_ranks}\n"
-            f"You can set this timeout with the {REPORT_BARRIER_TIMEOUT_S_ENV_VAR} "
-            f"environment variable (current value: {timeout_s:.2f} s)."
+            f"The collective operation timed out after {time_elapsed:.2f} seconds. "
+            f"The following ranks have not joined the collective operation: {missing_ranks}\n"
+            f"You can set the timeout with the {COLLECTIVE_TIMEOUT_S_ENV_VAR} "
+            f"environment variable (current value: {timeout_s:.2f} seconds). "
+            "Disable the timeout by setting the environment variable to -1."
         )
         super().__init__(message)
diff --git a/python/ray/train/v2/_internal/execution/checkpoint/sync_actor.py b/python/ray/train/v2/_internal/execution/checkpoint/sync_actor.py
index ee26ad967da7..914a30cca14d 100644
--- a/python/ray/train/v2/_internal/execution/checkpoint/sync_actor.py
+++ b/python/ray/train/v2/_internal/execution/checkpoint/sync_actor.py
@@ -5,9 +5,9 @@
 import ray
 from ray.train.v2._internal.constants import (
-    DEFAULT_REPORT_BARRIER_TIMEOUT_S,
-    DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
-    REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
+    COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
+    DEFAULT_COLLECTIVE_TIMEOUT_S,
+    DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
 )
 from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError
@@ -35,8 +35,8 @@ class SynchronizationActor:
     def __init__(
         self,
-        timeout_s: float = DEFAULT_REPORT_BARRIER_TIMEOUT_S,
-        warn_interval_s: float = DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
+        timeout_s: float = DEFAULT_COLLECTIVE_TIMEOUT_S,
+        warn_interval_s: float = DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
     ):
         self._counter: int = 0
         self._world_size: int = 0
@@ -139,7 +139,7 @@ async def _wait_with_logging(
                     world_size=self._world_size,
                     max_time_elapsed_s=self._get_time_elapsed(),
                     missing_ranks=self._get_missing_ranks(),
-                    warn_interval_env_var=REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
+                    warn_interval_env_var=COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
                     warn_interval_s=self._warn_interval_s,
                 ),
             )
@@ -189,7 +189,7 @@ async def broadcast_from_rank_zero(
                 self._wait_with_logging(
                     self._condition, world_rank, caller_method_name
                 ),
-                timeout=self._timeout_s,
+                timeout=self._timeout_s if self._timeout_s >= 0 else None,
             )
             return self._reduced_data
         except (asyncio.TimeoutError, TimeoutError) as e:
diff --git a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
index 098fa62e34aa..5e26276cb6f7 100644
--- a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
+++ b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py
@@ -14,12 +14,12 @@
 from ray.runtime_env import RuntimeEnv
 from ray.train._internal.base_worker_group import BaseWorkerGroup
 from ray.train.v2._internal.constants import (
-    DEFAULT_REPORT_BARRIER_TIMEOUT_S,
-    DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
+    COLLECTIVE_TIMEOUT_S_ENV_VAR,
+    COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
+    DEFAULT_COLLECTIVE_TIMEOUT_S,
+    DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
     DEFAULT_WORKER_GROUP_START_TIMEOUT_S,
     DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
-    REPORT_BARRIER_TIMEOUT_S_ENV_VAR,
-    REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
     WORKER_GROUP_START_TIMEOUT_S_ENV_VAR,
     WORKER_HEALTH_CHECK_TIMEOUT_S_ENV_VAR,
     get_env_vars_to_propagate,
@@ -177,12 +177,12 @@ def __init__(
                 DEFAULT_WORKER_HEALTH_CHECK_TIMEOUT_S,
             )
         )
-        self._report_barrier_timeout_s = env_float(
-            REPORT_BARRIER_TIMEOUT_S_ENV_VAR, DEFAULT_REPORT_BARRIER_TIMEOUT_S
+        self._collective_timeout_s = env_float(
+            COLLECTIVE_TIMEOUT_S_ENV_VAR, DEFAULT_COLLECTIVE_TIMEOUT_S
         )
-        self._report_barrier_warn_interval_s = env_float(
-            REPORT_BARRIER_WARN_INTERVAL_S_ENV_VAR,
-            DEFAULT_REPORT_BARRIER_WARN_INTERVAL_S,
+        self._collective_warn_interval_s = env_float(
+            COLLECTIVE_WARN_INTERVAL_S_ENV_VAR,
+            DEFAULT_COLLECTIVE_WARN_INTERVAL_S,
         )
 
     ################################################################################
@@ -309,8 +309,8 @@ def _start_impl(
                 soft=False,
             )
         ).remote(
-            timeout_s=self._report_barrier_timeout_s,
-            warn_interval_s=self._report_barrier_warn_interval_s,
+            timeout_s=self._collective_timeout_s,
+            warn_interval_s=self._collective_warn_interval_s,
         )
 
         worker_group_state_builder.with_sync_actor(sync_actor)
diff --git a/python/ray/train/v2/tests/test_sync_actor.py b/python/ray/train/v2/tests/test_sync_actor.py
index 7cbb04a3a6dd..9c1811f6a8fa 100644
--- a/python/ray/train/v2/tests/test_sync_actor.py
+++ b/python/ray/train/v2/tests/test_sync_actor.py
@@ -1,6 +1,7 @@
 import pytest
 
 import ray
+from ray.train.v2._internal.constants import DEFAULT_COLLECTIVE_TIMEOUT_S
 from ray.train.v2._internal.exceptions import BroadcastCollectiveTimeoutError
 from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
 
@@ -14,10 +15,10 @@ def ray_start_4_cpus():
 
 @pytest.mark.parametrize("world_size", [1, 10, 1000])
 def test_broadcast_from_rank_0(world_size):
-    """The test checks if all workers can reach a consensus on a data.
+    """Check that rank 0 can broadcast data to all other workers.
 
     Every worker sends data with a string "data-{rank}" that is unique
-    to the worker. Expected to get a consensus data of "data-0".
-    Also checks if the counter is reset to 0 after all workers have data.
+    to the worker. Everyone should receive the data from rank 0, which is "data-0".
+    Also assert that the actor state is reset after the broadcast function returns.
     """
     sync_actor = SynchronizationActor.remote()
     # Test broadcast_from_rank_zero with a world size of 10
@@ -39,7 +40,7 @@
     assert ray.get(sync_actor.get_reduced_data.remote()) is None
 
 
-def test_hang():
+def test_hang_with_timeout():
     """The test checks if the workers are blocked and hang when
     the world size is greater than the number of workers.
     The workers should block and hang until the barrier is lifted.
@@ -61,13 +62,47 @@
     # after 1 second
     with pytest.raises(BroadcastCollectiveTimeoutError) as excinfo:
         ray.get(remote_tasks)
-    assert "The following ranks have not called it: [9]" in str(excinfo.value)
+    assert "The following ranks have not joined the collective operation: [9]" in str(
+        excinfo.value
+    )
+
+
+def test_hang_without_timeout():
+    """Test the default behavior of running with no collective timeout."""
+    assert DEFAULT_COLLECTIVE_TIMEOUT_S == -1
+
+    sync_actor = SynchronizationActor.remote()
+    remote_tasks = []
+    for rank in range(9):
+        remote_tasks.append(
+            sync_actor.broadcast_from_rank_zero.remote(
+                world_rank=rank,
+                world_size=10,
+                data=f"data-{rank}",
+                caller_method_name="broadcast_from_rank_zero",
+            )
+        )
+
+    # Wait briefly and verify that the tasks are still blocked,
+    # without any timeout error being raised.
+    done, _ = ray.wait(remote_tasks, num_returns=len(remote_tasks), timeout=2)
+    assert not done, "All tasks should be hanging, but some are done."
+
+    # Finish up once the last worker joins.
+    remote_tasks.append(
+        sync_actor.broadcast_from_rank_zero.remote(
+            world_rank=9,
+            world_size=10,
+            data="data-9",
+            caller_method_name="broadcast_from_rank_zero",
+        )
+    )
+    ray.get(remote_tasks)
 
 
 def test_world_size_mismatch():
     """The test checks if the workers are blocked and raise an value error
     when the world size is different. The workers should block and raise
-    an ValueError.
+    a ValueError.
     """
     sync_actor = SynchronizationActor.remote()
     remote_tasks = []
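
---

For reviewers who want to see the new timeout convention in isolation: below is a minimal, self-contained sketch of the pattern this patch applies in `SynchronizationActor.broadcast_from_rank_zero`, where a negative `timeout_s` (the new default, `DEFAULT_COLLECTIVE_TIMEOUT_S = -1`) is translated into `timeout=None` so `asyncio.wait_for` waits indefinitely. The `env_float` and `wait_on_barrier` helpers here are illustrative stand-ins, not Ray's internal API.

```python
import asyncio
import os


def env_float(var: str, default: float) -> float:
    """Read a float from the environment, falling back to a default."""
    value = os.environ.get(var)
    return float(value) if value is not None else default


async def wait_on_barrier(barrier_released: asyncio.Event, timeout_s: float) -> None:
    # Same convention as the diff: a negative timeout means "no timeout",
    # which asyncio.wait_for expresses as timeout=None.
    await asyncio.wait_for(
        barrier_released.wait(),
        timeout=timeout_s if timeout_s >= 0 else None,
    )


async def main() -> None:
    # -1 mirrors DEFAULT_COLLECTIVE_TIMEOUT_S: wait indefinitely by default.
    timeout_s = env_float("RAY_TRAIN_COLLECTIVE_TIMEOUT_S", -1)
    barrier_released = asyncio.Event()
    # Release the barrier shortly so this example terminates.
    asyncio.get_running_loop().call_later(0.1, barrier_released.set)
    await wait_on_barrier(barrier_released, timeout_s)
    print("barrier released")


asyncio.run(main())
```

Running this with `RAY_TRAIN_COLLECTIVE_TIMEOUT_S=0.01` makes `asyncio.wait_for` raise a timeout error before the barrier is released, which is the case the patch converts into `BroadcastCollectiveTimeoutError`.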