Skip to content

Commit 47480b4

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__
PiperOrigin-RevId: 736261724
1 parent 8674495 commit 47480b4

File tree

6 files changed

+47
-13
lines changed

6 files changed

+47
-13
lines changed

jax/_src/mesh.py

-6
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,3 @@ def set_concrete_mesh(mesh: Mesh | None):
587587

588588
def get_concrete_mesh():
589589
return jax_config.device_context.value
590-
591-
592-
@contextlib.contextmanager
593-
def use_mesh(mesh: Mesh):
594-
with set_abstract_mesh(mesh.abstract_mesh), set_concrete_mesh(mesh):
595-
yield

jax/_src/pjit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def _infer_params(
690690
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
691691
) -> tuple[PjitParams, list[Any]]:
692692
if ji.use_resource_env:
693-
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
693+
with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
694694
return _infer_params_internal(fun, ji, args, kwargs)
695695
return _infer_params_internal(fun, ji, args, kwargs)
696696

jax/_src/sharding_impls.py

+26
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import collections
18+
import contextlib
1819
from collections.abc import Mapping, Sequence
1920
import dataclasses
2021
import functools
@@ -1410,3 +1411,28 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
14101411
allow_split_physical_axes=allow_split_physical_axes)
14111412
axis_types = _get_axis_types(auto_axes, explicit_axes, manual_axes)
14121413
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
1414+
1415+
1416+
@contextlib.contextmanager
1417+
def use_mesh(mesh: mesh_lib.Mesh):
1418+
if not isinstance(mesh, mesh_lib.Mesh):
1419+
raise ValueError(
1420+
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
1421+
1422+
# TODO(yashkatariya): Enable this.
1423+
# if not core.trace_state_clean():
1424+
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
1425+
1426+
with (mesh_lib.set_abstract_mesh(mesh.abstract_mesh),
1427+
mesh_lib.set_concrete_mesh(mesh)):
1428+
yield
1429+
1430+
def set_mesh(mesh: mesh_lib.Mesh) -> None:
1431+
if not isinstance(mesh, mesh_lib.Mesh):
1432+
raise ValueError(
1433+
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
1434+
if not core.trace_state_clean():
1435+
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
1436+
1437+
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
1438+
config.device_context.set_local(mesh)

jax/_src/test_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1576,7 +1576,7 @@ def with_user_mesh(sizes, names, axis_types=None):
15761576
def decorator(fn):
15771577
def mesh_fn(*args, **kwargs):
15781578
mesh = create_mesh(sizes, names, axis_types=axis_types)
1579-
with mesh_lib.use_mesh(mesh):
1579+
with jax.sharding.use_mesh(mesh):
15801580
return fn(*args, **kwargs, mesh=mesh)
15811581
return mesh_fn
15821582
return decorator

jax/sharding.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
PmapSharding as PmapSharding,
2323
GSPMDSharding as GSPMDSharding,
2424
PositionalSharding as PositionalSharding,
25+
use_mesh as use_mesh,
26+
set_mesh as set_mesh,
2527
)
2628
from jax._src.partition_spec import (
2729
PartitionSpec as PartitionSpec,
@@ -30,7 +32,6 @@
3032
from jax._src.mesh import (
3133
AbstractMesh as AbstractMesh,
3234
AxisTypes as AxisTypes,
33-
use_mesh as use_mesh
3435
)
3536

3637
_deprecations = {

tests/pjit_test.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -5943,7 +5943,7 @@ def f(x, x2):
59435943
a = z @ x2
59445944
return a
59455945

5946-
with mesh_lib.use_mesh(mesh):
5946+
with jax.sharding.use_mesh(mesh):
59475947
out = f(arr, arr.T)
59485948
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
59495949
lowered_text = f.lower(arr, arr.T).as_text()
@@ -5952,7 +5952,7 @@ def f(x, x2):
59525952
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
59535953
axis_types={mesh_lib.AxisTypes.Explicit: 'x',
59545954
mesh_lib.AxisTypes.Auto: 'y'})
5955-
with mesh_lib.use_mesh(mesh2):
5955+
with jax.sharding.use_mesh(mesh2):
59565956
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
59575957
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
59585958
out = f(arr, arr2)
@@ -5966,7 +5966,7 @@ def f(x, x2):
59665966
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
59675967
axis_types={mesh_lib.AxisTypes.Explicit: 'y',
59685968
mesh_lib.AxisTypes.Auto: 'x'})
5969-
with mesh_lib.use_mesh(mesh3):
5969+
with jax.sharding.use_mesh(mesh3):
59705970
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
59715971
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x')))
59725972
out = f(arr, arr2)
@@ -6143,7 +6143,7 @@ def test_inputs_different_context(self, mesh):
61436143
arr = jax.device_put(np_inp, s)
61446144

61456145
auto_mesh = jax.make_mesh((2,), 'x', auto_axes='x')
6146-
with mesh_lib.use_mesh(auto_mesh):
6146+
with jax.sharding.use_mesh(auto_mesh):
61476147
arr2 = jnp.ones(8)
61486148
self.assertDictEqual(arr2.sharding.mesh.axis_types,
61496149
{AxisTypes.Auto: ('x',)})
@@ -7083,6 +7083,19 @@ def f(x):
70837083
out = f(np.arange(8))
70847084
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
70857085

7086+
def test_set_mesh(self):
7087+
mesh = jtu.create_mesh((2,), ('x',), axis_types={AxisTypes.Explicit: 'x'})
7088+
prev_mesh = config.device_context.value
7089+
prev_abstract_mesh = config.abstract_mesh_context_manager.value
7090+
try:
7091+
jax.sharding.set_mesh(mesh)
7092+
7093+
out = reshard(np.arange(8), P('x'))
7094+
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
7095+
finally:
7096+
config.device_context.set_local(prev_mesh)
7097+
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
7098+
70867099

70877100
@jtu.pytest_mark_if_available('multiaccelerator')
70887101
class PJitErrorTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)