Skip to content

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711859223
  • Loading branch information
Pathways-on-Cloud Team authored and copybara-github committed Jan 3, 2025
1 parent a716b9b commit 0e8197b
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions pathwaysutils/test/google_internal/persistence_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Persistence tests that can only run in google3."""

from absl import flags
import jax

from google3.learning.pathways.ifrt.proxy.jax.tests import register_jax_grpc_backend_for_testing # pylint: disable=unused-import
from absl.testing import absltest


_JAX_BACKEND_TARGET = flags.DEFINE_string(
"jax_backend_target",
"ifrt_pathways",
"Jax backend target to use.",
)

_JAX_PLATFORMS = flags.DEFINE_string(
"jax_platforms",
"proxy",
"Jax platforms to use.",
)

# set JAX_ALLOW_UNUSED_TPUS to avoid the error below
#
# AssertionError: The host has 4 TPU chips
# but TPU support is not linked into JAX. You should add a BUILD dependency
# on //learning/brain/research/jax:tpu_support."
#
# This error happens because we are
# //learning/pathways/data_parallel:tpu_support instead of the more common
# //learning/brain/research/jax:tpu_support
flags.FLAGS.jax_allow_unused_tpus = True


class PersistenceTest(absltest.TestCase):

def test_devices_can_be_fetched_from_proxy_backend(self):
devices = jax.devices("proxy")
self.assertNotEmpty(devices)


if __name__ == "__main__":
absltest.main()

0 comments on commit 0e8197b

Please sign in to comment.