
Commit a341722

Add callback and fix other comments
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 360b952 commit a341722

9 files changed (+271, -159 lines)

python/ray/_private/accelerators/tpu.py

Lines changed: 0 additions & 19 deletions
@@ -9,7 +9,6 @@
 
 import ray
 from ray._private.accelerators.accelerator import AcceleratorManager
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 logger = logging.getLogger(__name__)
 
@@ -128,24 +127,6 @@ def infer_tpu_pod_type_from_topology(
     return None
 
 
-def fetch_tpu_slice_name_from_pg(pg):
-    @ray.remote(num_cpus=0)
-    def _get_tpu_slice_name():
-        import ray
-
-        return (
-            ray._private.accelerators.TPUAcceleratorManager.get_current_node_tpu_name()
-        )
-
-    tpu_name_ref = _get_tpu_slice_name.options(
-        scheduling_strategy=PlacementGroupSchedulingStrategy(
-            placement_group=pg, placement_group_bundle_index=0
-        )
-    ).remote()
-
-    return ray.get(tpu_name_ref)
-
-
 class TPUAcceleratorManager(AcceleratorManager):
     """Google TPU accelerators."""

python/ray/train/v2/BUILD

Lines changed: 17 additions & 1 deletion
@@ -488,7 +488,23 @@ py_test(
 py_test(
     name = "test_jax_trainer",
     size = "small",
-    srcs = ["tests/test_xgboost_trainer.py"],
+    srcs = ["tests/test_jax_trainer.py"],
+    env = {"RAY_TRAIN_V2_ENABLED": "1"},
+    tags = [
+        "exclusive",
+        "team:ml",
+        "train_v2",
+    ],
+    deps = [
+        ":conftest",
+        "//:ray_lib",
+    ],
+)
+
+py_test(
+    name = "test_tpu_utils",
+    size = "small",
+    srcs = ["tests/test_tpu_utils.py"],
     env = {"RAY_TRAIN_V2_ENABLED": "1"},
     tags = [
         "exclusive",

python/ray/train/v2/_internal/callbacks/accelerators.py

Lines changed: 2 additions & 113 deletions
@@ -1,28 +1,19 @@
 import logging
 import os
 from collections import defaultdict
-from typing import List, Optional
+from typing import List
 
-import ray
 import ray._private.ray_constants as ray_constants
 from ray._private.accelerators.nvidia_gpu import CUDA_VISIBLE_DEVICES_ENV_VAR
-from ray._private.accelerators.tpu import (
-    fetch_tpu_slice_name_from_pg,
-    infer_tpu_pod_type_from_topology,
-)
-from ray._private.ray_constants import env_bool, env_integer
+from ray._private.ray_constants import env_bool
 from ray.train import BackendConfig
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
-    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
 )
 from ray.train.v2._internal.execution.callback import WorkerGroupCallback
 from ray.train.v2._internal.execution.worker_group import ActorMetadata, WorkerGroup
 from ray.train.v2._internal.util import ray_get_safe
 from ray.train.v2.api.config import ScalingConfig
-from ray.util.placement_group import (
-    PlacementGroup,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -161,105 +152,3 @@ def _get_visible_accelerator_ids_per_worker(
         visible_accelerator_ids_per_worker.append(all_resource_ids)
 
     return visible_accelerator_ids_per_worker
-
-
-def reserve_tpu_slice(
-    num_workers: int,
-    resources_per_worker: dict,
-    topology: Optional[str],
-    accelerator_type: Optional[str],
-) -> Optional[PlacementGroup]:
-    """Creates a SPMD-aware placement group. This currently only supports
-    TPU with JaxTrainer by reserving a multi-host slice.
-
-    This creates a head PG (for index 0) that reserves the `TPU-{}-head` resource
-    on the node, retrieves unique slice information from it, and then creates a
-    multi-host slice PG (for index 0..N-1) that reserves the `TPU` resource on all
-    the nodes in the slice. This enables atomic scheduling of TPU workers.
-
-    Args:
-        num_workers: Total number of workers to launch.
-        resources_per_worker: Resource requirements per bundle (e.g. {"CPU": 4}).
-        topology: The TPU topology string (e.g. "2x2x2").
-        accelerator_type: The accelerator type of the node (e.g. "TPU-V4").
-
-    Returns:
-        A PlacementGroup if able to be created, or None.
-    """
-    if not (topology and accelerator_type):
-        return None
-
-    pod_type = infer_tpu_pod_type_from_topology(topology, accelerator_type)
-    if pod_type is None:
-        return None
-
-    # Reserve a slice by creating a placement group on the
-    # TPU head.
-    head_label_selector = {
-        "ray.io/tpu-worker-id": "0",
-        "ray.io/tpu-pod-type": pod_type,
-    }
-    head_placement_group = ray.util.placement_group(
-        bundles=[{f"TPU-{pod_type}-head": 1}],
-        bundle_label_selector=[head_label_selector],
-    )
-
-    logger.debug("Waiting to reserve multi-host slice head.")
-    timeout = env_integer(TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
-    ready, _ = ray.wait([head_placement_group.ready()], timeout=timeout)
-
-    if not ready:
-        raise TimeoutError(
-            "Failed to reserve TPU head for slice with shape: {}. "
-            "Ensure your cluster has sufficient resources. Requesting TPU "
-            "head node with labels: {}. Current resources: {}".format(
-                pod_type, head_label_selector, ray.available_resources()
-            )
-        )
-
-    if num_workers == 1:
-        logger.debug("Reserved single-host TPU placement group.")
-        return head_placement_group
-
-    # Retrieve the unique slice ID.
-    slice_name = fetch_tpu_slice_name_from_pg(head_placement_group)
-    if slice_name is None:
-        raise RuntimeError(
-            "Failed to retrieve TPU slice name after reserving head placement group. "
-            "Ensure that TPU slice metadata is available and correctly configured on multi-host nodes."
-        )
-    slice_label_selector = {
-        "ray.io/tpu-slice-name": slice_name,
-        "ray.io/tpu-pod-type": pod_type,
-    }
-
-    # Schedule the remaining multi-host workers together with the head bundle.
-    slice_placement_group = ray.util.placement_group(
-        bundles=[resources_per_worker] * num_workers,
-        bundle_label_selector=[slice_label_selector] * num_workers,
-        strategy="SPREAD",
-    )
-    logger.debug("Waiting for multi-host slice placement group to start.")
-    timeout = env_integer(TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
-    ready, _ = ray.wait([slice_placement_group.ready()], timeout=timeout)
-
-    if ready:
-        logger.debug("SPMD placement groups have started.")
-    else:
-        ray.util.remove_placement_group(head_placement_group)
-        ray.util.remove_placement_group(slice_placement_group)
-        raise TimeoutError(
-            "SPMD Placement group creation timed out. Make sure your "
-            "cluster either has enough resources or use an "
-            "autoscaling cluster. Ensure your cluster has multi-host nodes "
-            "available for SPMD scheduling. "
-            "Current resources available: {}, resources requested by the "
-            "placement groups: {} with labels {}".format(
-                ray.available_resources(),
-                [resources_per_worker] * num_workers,
-                slice_label_selector,
-            )
-        )
-    ray.util.remove_placement_group(head_placement_group)
-
-    return slice_placement_group
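
The reserve_tpu_slice flow removed above is superseded by a helper in ray.train.v2.jax.tpu_utils, which the new TPUReservationCallback below imports. Judging from that call site, the relocated helper takes only the topology and accelerator type and hands back a slice name rather than a placement group. A hedged usage sketch based on that inferred signature (the topology and accelerator type values are placeholders, and the call only succeeds on a cluster with a matching multi-host TPU slice):

import ray
from ray.train.v2.jax.tpu_utils import reserve_tpu_slice

ray.init()

# Placeholder values; in Ray Train these normally come from the ScalingConfig.
slice_name = reserve_tpu_slice(topology="2x2x2", accelerator_type="TPU-V4")
if not slice_name:
    raise RuntimeError("Failed to reserve TPU slice.")

# The slice name becomes a per-bundle label selector for the worker group.
bundle_label_selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
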
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+from typing import Dict, Optional
+
+import ray
+from ray.train.v2._internal.execution.callback import ControllerCallback
+from ray.train.v2.api.config import ScalingConfig
+from ray.train.v2.jax.tpu_utils import reserve_tpu_slice
+
+
+class TPUReservationCallback(ControllerCallback):
+    """A callback to handle TPU slice reservation for multi-host training."""
+
+    def on_controller_start_worker_group(
+        self, *, scaling_config: ScalingConfig, num_workers: int
+    ) -> Optional[Dict[str, str]]:
+        """Reserves a multi-host TPU slice before the worker group starts.
+
+        This hook is called by the TrainController. It checks if multi-host
+        TPUs are being used and, if so, reserves a slice.
+
+        Args:
+            scaling_config: The scaling configuration for the run.
+            num_workers: The number of workers to be started.
+
+        Returns:
+            A dictionary defining a `bundle_label_selector` to gang schedule
+            the worker group on the reserved TPU slice.
+        """
+        bundle_label_selector = None
+
+        if getattr(scaling_config, "use_tpu", False) and num_workers > 1:
+            slice_name = reserve_tpu_slice(
+                topology=getattr(scaling_config, "topology", None),
+                accelerator_type=getattr(scaling_config, "accelerator_type", None),
+            )
+            if not slice_name:
+                raise RuntimeError("Failed to reserve TPU slice.")
+
+            bundle_label_selector = {
+                ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name
+            }
+
+        return bundle_label_selector
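
Because the callback only probes the scaling config with getattr, its contract can be exercised in isolation with a stand-in object. A minimal sketch (the use_tpu, topology, and accelerator_type fields mirror what the getattr calls above expect; reserve_tpu_slice still needs a cluster with a matching multi-host TPU slice to return a name):

from types import SimpleNamespace

# Stand-in for a ScalingConfig carrying the multi-host TPU fields probed above.
scaling_config = SimpleNamespace(
    use_tpu=True, topology="2x2x2", accelerator_type="TPU-V4"
)

callback = TPUReservationCallback()
selector = callback.on_controller_start_worker_group(
    scaling_config=scaling_config, num_workers=4
)
# `selector` is None when no multi-host TPU is requested, otherwise a one-entry
# dict mapping the slice-name label key to the reserved slice.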

python/ray/train/v2/_internal/execution/callback.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ray.train.v2.api.callback import RayTrainCallback
+from ray.train.v2.api.config import ScalingConfig
 from ray.train.v2.api.result import Result
 from ray.util.annotations import DeveloperAPI
 
@@ -78,6 +79,24 @@ def after_controller_start(self, train_run_context: "TrainRunContext"):
         before the control loop starts executing."""
         pass
 
+    def on_controller_start_worker_group(
+        self, *, scaling_config: ScalingConfig, num_workers: int
+    ) -> Optional[Dict[str, str]]:
+        """Called by the TrainController before the worker group is started.
+
+        This hook can be used to perform setup that modifies the worker group's
+        placement, such as reserving an accelerator slice.
+
+        Args:
+            scaling_config: The scaling configuration for the run.
+            num_workers: The number of workers to be started.
+
+        Returns:
+            An optional dictionary defining a `bundle_label_selector`
+            to gang schedule the worker group on the reserved TPU slice.
+        """
+        return None
+
     def before_controller_shutdown(self):
         """Called before `TrainController.run` exits,
         after the control loop has exited."""
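
The base implementation returns None, so existing callbacks are unaffected; a callback opts in to steering placement by returning a selector. A hedged sketch of a custom ControllerCallback (the label key is purely illustrative):

from typing import Dict, Optional

from ray.train.v2._internal.execution.callback import ControllerCallback
from ray.train.v2.api.config import ScalingConfig


class PinToLabeledNodesCallback(ControllerCallback):
    """Illustrative only: ask that all worker bundles land on nodes with a label."""

    def on_controller_start_worker_group(
        self, *, scaling_config: ScalingConfig, num_workers: int
    ) -> Optional[Dict[str, str]]:
        # The controller applies the returned selector to every bundle in the
        # worker group's placement group.
        return {"example.com/pool": "reserved"}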

python/ray/train/v2/_internal/execution/controller/controller.py

Lines changed: 16 additions & 17 deletions
@@ -9,9 +9,6 @@
 
 import ray
 import ray._private.ray_constants as ray_constants
-from ray.train.v2._internal.callbacks.accelerators import (
-    reserve_tpu_slice,
-)
 from ray.train.v2._internal.constants import (
     DEFAULT_ENABLE_CONTROLLER_LOGGING,
     DEFAULT_HEALTH_CHECK_INTERVAL_S,
@@ -283,27 +280,29 @@ def _start_worker_group(
             ControllerError if the worker group failed to start.
         """
         placement_strategy = self._scaling_policy.scaling_config.placement_strategy
-        placement_group = None
-        backend_config = self._train_run_context.backend_config
-
-        if getattr(backend_config, "use_tpu", False):
-            try:
-                placement_group = reserve_tpu_slice(
-                    num_workers=num_workers,
-                    resources_per_worker=resources_per_worker,
-                    topology=getattr(backend_config, "topology", None),
-                    accelerator_type=getattr(backend_config, "accelerator_type", None),
-                )
-            except Exception as e:
-                return ControllerError(e)
+        scaling_config = self._train_run_context.scaling_config
+
+        # Check for `bundle_label_selector` to influence WorkerGroup scheduling.
+        bundle_label_selector = None
+        try:
+            for callback in self._callbacks:
+                if hasattr(callback, "on_controller_start_worker_group"):
+                    selector = callback.on_controller_start_worker_group(
+                        scaling_config=scaling_config, num_workers=num_workers
+                    )
+                    if selector:
+                        bundle_label_selector = selector
+                        break
+        except Exception as e:
+            return ControllerError(e)
 
         worker_group_context = WorkerGroupContext(
             run_attempt_id=self._get_run_attempt_id(),
             train_fn_ref=self._train_fn_ref,
             num_workers=num_workers,
             resources_per_worker=resources_per_worker,
             placement_strategy=placement_strategy,
-            placement_group=placement_group,
+            bundle_label_selector=bundle_label_selector,
         )
         try:
             self._worker_group = self.worker_group_cls.create(
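
The controller no longer builds a TPU placement group itself; it polls its callbacks, and the first non-empty selector wins. The selection rule, pulled out as a standalone sketch:

from typing import Dict, List, Optional


def pick_bundle_label_selector(
    callbacks: List[object], scaling_config, num_workers: int
) -> Optional[Dict[str, str]]:
    # Mirrors the controller loop above: the first callback that returns a
    # non-empty selector determines the worker group's bundle labels.
    for callback in callbacks:
        if hasattr(callback, "on_controller_start_worker_group"):
            selector = callback.on_controller_start_worker_group(
                scaling_config=scaling_config, num_workers=num_workers
            )
            if selector:
                return selector
    return None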

python/ray/train/v2/_internal/execution/worker_group/worker_group.py

Lines changed: 15 additions & 9 deletions
@@ -89,15 +89,15 @@ class WorkerGroupContext:
         num_workers: The number of workers in the worker group.
         resources_per_worker: The resources per worker.
         placement_strategy: Strategy for placing workers.
-        placement_group: Optional override placement group to schedule workers to.
+        bundle_label_selector: Optional label selectors to apply per-bundle for workers.
     """
 
     run_attempt_id: str
     train_fn_ref: ObjectRefWrapper[Callable[[], None]]
     num_workers: int
     resources_per_worker: Dict[str, float]
     placement_strategy: str = "PACK"
-    placement_group: Optional[PlacementGroup] = None
+    bundle_label_selector: Optional[Dict[str, str]] = None
 
 
 class WorkerGroup:
@@ -255,7 +255,6 @@ def _start_impl(
         """
         self._assert_inactive()
         worker_group_context = self._worker_group_context
-        pg = worker_group_context.placement_group
 
         WorkerGroup._check_cluster_resources_and_raise_if_insufficient(
             worker_group_context.resources_per_worker,
@@ -271,12 +270,19 @@
         for callback in self._callbacks:
             callback.before_worker_group_start(worker_group_context)
 
-        if pg is None:
-            pg = placement_group(
-                bundles=[worker_group_context.resources_per_worker]
-                * worker_group_context.num_workers,
-                strategy=worker_group_context.placement_strategy,
-            )
+        bundle_label_selector = (
+            [worker_group_context.bundle_label_selector.copy()]
+            * worker_group_context.num_workers
+            if worker_group_context.bundle_label_selector
+            else None
+        )
+
+        pg = placement_group(
+            bundles=[worker_group_context.resources_per_worker]
+            * worker_group_context.num_workers,
+            strategy=worker_group_context.placement_strategy,
+            bundle_label_selector=bundle_label_selector,
+        )
         logger.info(
             f"Attempting to start training worker group of size {worker_group_context.num_workers} with "
             f"the following resources: [{worker_group_context.resources_per_worker}] * {worker_group_context.num_workers}"
