
Commit cc21245
Merge pull request jax-ml#22481 from zhenying-liu:offloading
PiperOrigin-RevId: 657413977
jax authors committed Jul 30, 2024
2 parents 3003754 + c774d7b commit cc21245
Showing 2 changed files with 41 additions and 4 deletions.
tests/layout_test.py (17 changes: 15 additions & 2 deletions)
@@ -40,11 +40,13 @@ def tearDownModule():
 class LayoutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(['tpu']):
-      self.skipTest("Layouts do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(['tpu', 'gpu']):
+      self.skipTest("Layouts do not work on CPU backend yet.")
     super().setUp()
 
   def test_auto_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape1 = (128, 128)
     shape2 = (128, 128)
@@ -110,6 +112,8 @@ def init(x, y):
     self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
 
   def test_default_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -149,6 +153,8 @@ def f(x):
             out_shardings=DLL.AUTO).lower(sds).compile()
 
   def test_in_layouts_out_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (8, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -173,6 +179,8 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
 
   def test_sharding_and_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -235,6 +243,8 @@ def f(x, y):
     compiled(*arrs)
 
   def test_aot_layout_mismatch(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (256, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -404,6 +414,9 @@ def f(x):
     self.assertArraysEqual(out, inp.T)
 
   def test_device_put_user_concrete_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
+
     shape = (8, 128)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     dll = DLL(major_to_minor=(1, 0))
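The layout_test.py changes above enable the GPU backend in setUp while skipping the individual tests that do not yet pass there. For orientation, the pattern these tests exercise is requesting an explicit device-local layout through device_put. The sketch below is not part of this commit; it reuses the DLL(major_to_minor=...) call visible in the diff and assumes the jax.experimental.layout import path and the Layout(device_local_layout, sharding) constructor order used by layout_test.py.

```python
# Minimal sketch of the concrete-layout pattern exercised by
# test_device_put_user_concrete_layout (not part of this commit; the
# Layout/DLL import path and constructor order are assumptions).
import math

import numpy as np

import jax
from jax.experimental.layout import DeviceLocalLayout as DLL, Layout
from jax.sharding import SingleDeviceSharding

shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)

# Ask for the transposed in-memory layout: dimension 1 major, dimension 0 minor.
dll = DLL(major_to_minor=(1, 0))
s = SingleDeviceSharding(jax.devices()[0])
out = jax.device_put(np_inp, Layout(dll, s))

# The logical values are unchanged; only the on-device layout differs.
np.testing.assert_array_equal(out, np_inp)
```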
tests/memories_test.py (28 changes: 26 additions & 2 deletions)
@@ -190,8 +190,8 @@ def test_default_memory_kind(self):
 class DevicePutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(["tpu", "gpu"]):
+      self.skipTest("Memories do not work on CPU backend yet.")
     super().setUp()
 
   def _check_device_put_addressable_shards(
@@ -215,6 +215,8 @@ def test_error_transfer_to_memory_kind_outside_jit(self):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_host_to_hbm(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     np_inp = np.arange(16).reshape(8, 2)
@@ -229,6 +231,8 @@ def test_device_put_host_to_hbm(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_hbm_to_host(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     inp = jnp.arange(16).reshape(8, 2)
@@ -246,6 +250,8 @@ def test_device_put_hbm_to_host(self, host_memory_kind: str):
   def test_device_put_different_device_and_memory_host_to_hbm(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -266,6 +272,8 @@ def test_device_put_different_device_and_memory_host_to_hbm(
   def test_device_put_different_device_and_memory_hbm_to_host(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -285,6 +293,8 @@ def test_device_put_different_device_and_memory_hbm_to_host(
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_on_different_device_with_the_same_memory_kind(
       self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if len(jax.devices()) < 2:
       raise unittest.SkipTest("Test requires >=2 devices.")
 
@@ -331,6 +341,8 @@ def test_device_put_on_different_device_with_the_same_memory_kind(
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_array(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     np_inp = np.arange(16).reshape(8, 2)
     s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
@@ -345,6 +357,8 @@ def test_device_put_numpy_array(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     np_inp = np.float32(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -358,6 +372,8 @@ def test_device_put_numpy_scalar(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_scalar = float(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -372,6 +388,8 @@ def test_device_put_python_scalar(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_int(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_inp = 8
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
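The device_put tests above all move values between device memory ("device") and a host memory kind, with unpinned_host skipped on GPU. As a rough sketch of that round trip (not part of this diff; it only combines the SingleDeviceSharding(..., memory_kind=...) and with_memory_kind calls that appear in the hunks above, and on GPU only pinned_host is expected to work after this change):

```python
# Sketch of a device <-> pinned-host round trip with device_put
# (not part of this commit; backend behavior may vary).
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s_dev = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s_host = s_dev.with_memory_kind("pinned_host")

x = jnp.arange(16).reshape(8, 2)
x_host = jax.device_put(x, s_host)      # device (HBM) -> pinned host memory
x_back = jax.device_put(x_host, s_dev)  # pinned host memory -> device

assert x_host.sharding.memory_kind == "pinned_host"
assert x_back.sharding.memory_kind == "device"
```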
@@ -399,6 +417,8 @@ def f(a, b):
         out, np_inp * np_inp, s_dev, "device")
 
   def test_parameter_streaming(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     _, s_host, np_inp, inp_host = _create_inputs(
         (8, 2), P("x", "y"), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')
@@ -422,6 +442,8 @@ def f(a, b):
         out2, np_inp * np_inp * 2, s_host, 'pinned_host')
 
   def test_parameter_streaming_with_scalar_and_constant(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     scalar_inp = 1
     s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
@@ -569,6 +591,8 @@ def f(x):
         out_host, np_inp * 2, s_host, 'pinned_host')
 
   def test_output_streaming_inside_scan(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
       self.skipTest("This test requires an xla_version >= 2.")
     mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
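The parameter- and output-streaming tests above compile functions whose inputs or outputs live in pinned host memory rather than device memory; they remain TPU-only after this change. A hedged sketch of the output-streaming side of that pattern follows. It is not taken from this diff: passing a host-memory sharding as out_shardings is an assumption based on the tests' names and check helpers, and backend support may vary.

```python
# Sketch of streaming a jit output to pinned host memory via out_shardings
# (not part of this commit; the out_shardings usage is an assumption and
# backend support may vary).
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s_dev = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s_host = s_dev.with_memory_kind("pinned_host")

def f(x):
  return x * 2  # computed in device memory

f_stream = jax.jit(f, out_shardings=s_host)  # result placed in pinned host memory
y = f_stream(jnp.arange(8.0))

assert y.sharding.memory_kind == "pinned_host"
```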
