Skip to content

Commit

Permalink
Access thread_resources via jax.interpreters.pxla instead of jax.expe…
Browse files Browse the repository at this point in the history
…rimental.maps

The maps submodule is deprecated and will be removed soon.

PiperOrigin-RevId: 617511370
  • Loading branch information
superbobry authored and Flax Authors committed Mar 20, 2024
1 parent 006da41 commit 3b401ee
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flax import errors, struct
from flax.typing import LogicalNames
import jax
from jax.experimental import maps
from jax.interpreters import pxla

A = TypeVar('A')
B = TypeVar('B')
Expand Down Expand Up @@ -178,9 +178,9 @@ def inner_update(c, v):


def _global_mesh_defined() -> bool:
"""Checks if global xmap/pjit mesh resource environment is defined."""
maps_env = maps.thread_resources.env
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
"""Checks if global mesh resource environment is defined."""
env = pxla.thread_resources.env
return env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison


class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
Expand Down
8 changes: 4 additions & 4 deletions flax/experimental/nnx/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import typing as tp

import jax
from jax.experimental import maps
from jax.interpreters import pxla
from jax.sharding import Mesh, PartitionSpec

from flax.experimental.nnx.nnx import variables
Expand Down Expand Up @@ -126,9 +126,9 @@ def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A:


def _global_mesh_defined() -> bool:
"""Checks if global xmap/pjit mesh resource environment is defined."""
maps_env = maps.thread_resources.env
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
"""Checks if global mesh resource environment is defined."""
env = pxla.thread_resources.env
return env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison


def _with_sharding_constraint(
Expand Down
8 changes: 4 additions & 4 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import jax
from jax import lax
from jax.experimental import maps
from jax.interpreters import pxla

from flax import struct
from flax.core import meta
Expand Down Expand Up @@ -209,9 +209,9 @@ def logical_to_mesh_sharding(


def _global_mesh_defined() -> bool:
"""Checks if global xmap/jit mesh resource environment is defined."""
maps_env = maps.thread_resources.env
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
"""Checks if global mesh resource environment is defined."""
env = pxla.thread_resources.env
return env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison


class RulesFallback(enum.Enum):
Expand Down

0 comments on commit 3b401ee

Please sign in to comment.