Skip to content

Commit 3b0e3fa

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Add utilities for working with local steps on Pathways.
PiperOrigin-RevId: 834320290
1 parent fa18d5f commit 3b0e3fa

File tree

3 files changed

+96
-37
lines changed

3 files changed

+96
-37
lines changed

checkpoint/orbax/checkpoint/_src/multihost/multihost.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def is_pathways_backend() -> bool:
6767
)
6868

6969

70+
def is_pathways_controller() -> bool:
71+
return jax.local_devices()[0].client.runtime_type == 'pathways'
72+
73+
7074
def is_runtime_to_distributed_ids_initialized() -> bool:
7175
return _RUNTIME_TO_DISTRIBUTED_ID is not None
7276

checkpoint/orbax/checkpoint/_src/multihost/pathways.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,27 @@
1616

1717
import functools
1818
import jax
19-
import numpy as np
20-
from .learning.deepmind.jax.ocean.remote_python import rp
2119

2220

2321
@functools.lru_cache(maxsize=1)
24-
def worker_count() -> int:
25-
"""Gets the number of Pathways workers."""
26-
fully_replicated_sharding = jax.sharding.NamedSharding(
27-
jax.sharding.Mesh(jax.devices(), 'x'),
28-
jax.sharding.PartitionSpec(),
29-
)
22+
def worker_count(global_mesh: jax.sharding.Mesh | None) -> int:
23+
"""Gets the number of Pathways workers.
3024
31-
@rp.stateless_fn
32-
def _get_worker_count(_) -> jax.Array:
33-
wc = np.asarray(jax.process_count(), dtype=np.int32)
34-
return jax.make_array_from_callback(
35-
(),
36-
fully_replicated_sharding,
37-
lambda _: wc,
38-
dtype=np.int32,
39-
)
25+
Args:
26+
global_mesh: The global mesh of active devices. If None is provided,
27+
`jax.devices()` will be used.
4028
41-
dummy_input = jax.device_put(
42-
np.asarray(0, dtype=np.int32),
43-
device=fully_replicated_sharding,
44-
)
45-
_get_worker_count.register_shape_fn(
46-
lambda _: jax.ShapeDtypeStruct(
47-
(), dtype=np.int32, sharding=fully_replicated_sharding
48-
)
49-
)
50-
result = _get_worker_count(rp.to_remote_python(dummy_input))
51-
jax.block_until_ready(result)
52-
result = rp.from_remote_python(result)
53-
return result.item()
29+
Returns:
30+
The number of Pathways workers in the mesh.
31+
"""
32+
global_mesh = global_mesh or jax.sharding.Mesh(jax.devices(), 'x')
33+
devices = global_mesh.devices.flatten()
34+
workers = set()
35+
for d in devices:
36+
attrs = []
37+
if hasattr(d, 'virtual_task_index'):
38+
attrs.append(d.virtual_task_index)
39+
if hasattr(d, 'slice_index'):
40+
attrs.append(d.slice_index)
41+
workers.add(tuple(attrs))
42+
return len(workers)

checkpoint/orbax/checkpoint/testing/local_path.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,44 @@
1818

1919
import os
2020
import typing
21-
from typing import Iterator
21+
from typing import Iterator, Sequence
2222

2323
from etils import epath
2424
from orbax.checkpoint._src.multihost import multihost
2525

2626

27+
_LOCAL_PATH_BASE_NAME = '_local_path_base'
28+
_LOCAL_PART_PREFIX = 'local'
29+
30+
2731
# The following is a hack to pass the type checker.
2832
if typing.TYPE_CHECKING:
2933
_BasePath = epath.Path
3034
else:
3135
_BasePath = object
3236

3337

38+
def create_local_path_base(testclass) -> epath.Path:
39+
return epath.Path(
40+
testclass.create_tempdir(name=_LOCAL_PATH_BASE_NAME).full_path
41+
)
42+
43+
44+
def _get_local_part_index(parts: Sequence[str]) -> int:
45+
for i, part in enumerate(parts):
46+
if part.startswith(_LOCAL_PART_PREFIX):
47+
return i
48+
raise ValueError(
49+
f'Did not find a local part ({_LOCAL_PART_PREFIX}) in parts: {parts}'
50+
)
51+
52+
3453
class LocalPath(_BasePath):
3554
"""A Path implementation for testing process-local paths.
3655
56+
IMPORTANT: Use `create_local_path_base` to create the base path for test
57+
cases.
58+
3759
In the future, this class may more completely provide all functions and
3860
properties of a pathlib Path, but for now, it only provides the minimum
3961
needed to support relevant tests.
@@ -54,9 +76,12 @@ class LocalPath(_BasePath):
5476

5577
def __init__(self, *parts: epath.PathLike):
5678
self._path = epath.Path('/'.join(os.fspath(p) for p in parts))
79+
# Assumes this class will always be constructed on the controller first
80+
# (otherwise this check will return the wrong value).
81+
self._is_pathways_backend = multihost.is_pathways_backend()
5782

5883
def __repr__(self) -> str:
59-
return f'{self.__class__.__name__}({self.path})'
84+
return f'LocalPath({self.path})'
6085

6186
def __str__(self) -> str:
6287
return str(self.path)
@@ -67,7 +92,40 @@ def base_path(self) -> epath.Path:
6792

6893
@property
6994
def path(self) -> epath.Path:
70-
return self.base_path / str(f'local_{multihost.process_index()}')
95+
parts = list(self.base_path.parts)
96+
97+
# Fail if the path is not properly configured. The local part should be
98+
# immediately following the base name.
99+
try:
100+
base_idx = parts.index(_LOCAL_PATH_BASE_NAME)
101+
except ValueError as e:
102+
raise ValueError(
103+
f'Base path for LocalPath must contain {_LOCAL_PATH_BASE_NAME}. Got:'
104+
f' {self.base_path}'
105+
) from e
106+
107+
if multihost.is_pathways_controller():
108+
local_part = f'{_LOCAL_PART_PREFIX}_controller'
109+
else:
110+
local_part = f'{_LOCAL_PART_PREFIX}_{multihost.process_index()}'
111+
112+
try:
113+
# If the local part is already present, potentially replace it with the
114+
# correct local part (e.g. controller vs worker).
115+
local_part_idx = _get_local_part_index(parts)
116+
assert local_part_idx == base_idx + 1
117+
parts[local_part_idx] = local_part
118+
return epath.Path(*parts)
119+
except ValueError:
120+
pass
121+
122+
# Otherwise, insert following the base part.
123+
parts.insert(base_idx + 1, local_part)
124+
return epath.Path(*parts)
125+
126+
@property
127+
def parts(self) -> tuple[str, ...]:
128+
return self.path.parts
71129

72130
def exists(self) -> bool:
73131
"""Returns True if self exists."""
@@ -119,6 +177,14 @@ def unlink(self, missing_ok: bool = False) -> None:
119177
"""Remove this file or symbolic link."""
120178
self.path.unlink(missing_ok=missing_ok)
121179

180+
def touch(self, mode: int = 0o666, exist_ok: bool = False) -> None:
181+
"""Creates the file at this path."""
182+
self.path.touch(exist_ok=exist_ok)
183+
184+
def rename(self, new_path: epath.PathLike) -> None:
185+
"""Renames this file or directory to the given path."""
186+
self.path.rename(new_path)
187+
122188
def write_bytes(self, data: bytes) -> int:
123189
"""Writes content as bytes."""
124190
return self.path.write_bytes(data)
@@ -135,16 +201,16 @@ def write_text(
135201
def as_posix(self) -> str:
136202
return self.path.as_posix()
137203

138-
def __truediv__(self, key: epath.PathLike) -> epath.Path:
139-
return self.path / key
204+
def __truediv__(self, key: epath.PathLike) -> LocalPath:
205+
return LocalPath(self.path / key)
140206

141207
@property
142208
def name(self) -> str:
143209
return self.path.name
144210

145211
@property
146-
def parent(self) -> epath.Path:
147-
return self.path.parent
212+
def parent(self) -> LocalPath:
213+
return LocalPath(self.path.parent)
148214

149215
def __fspath__(self) -> str:
150216
return os.fspath(self.path)

0 commit comments

Comments
 (0)