Skip to content

Commit

Permalink
Allow passing complex objects in static_args of faked pmaps.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 535270829
  • Loading branch information
hbq1 authored and ChexDev committed May 25, 2023
1 parent 9a7b5ef commit 42a39e2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
8 changes: 3 additions & 5 deletions chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,11 @@ def wrapped_fn(*args, **kwargs):
if static_broadcasted_argnums:
# Make sure vmap does not try to map over `static_broadcasted_argnums`.
if isinstance(in_axes, int):
vmap_in_axes = jax.tree_util.tree_map(lambda _: in_axes, call_args)
vmap_in_axes = [in_axes] * len(call_args)
else:
vmap_in_axes = in_axes
vmap_in_axes = list(vmap_in_axes)
vmap_in_axes = list(in_axes)
for argnum in static_broadcasted_argnums:
vmap_in_axes[argnum] = jax.tree_util.tree_map(
lambda _: None, call_args[argnum])
vmap_in_axes[argnum] = None

# To protect the arguments from `static_broadcasted_argnums`,
# from turning into tracers (because of vmap), we capture the original
Expand Down
45 changes: 45 additions & 0 deletions chex/_src/fake_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for `fake.py`."""

import dataclasses
import functools

from absl.testing import absltest
Expand Down Expand Up @@ -263,6 +264,50 @@ def foo(x, multiplier, y, mode='bar'):
with self.assertRaises(ValueError):
result = func()

@parameterized.parameters(1, [1])
def test_pmap_with_complex_static_broadcasted_object(self, static_argnums):

@dataclasses.dataclass
class Multiplier:
x: int
y: int

def foo(x, multiplier, y):
if static_argnums == 1 or 1 in static_argnums:
# Verify that the static arguments are not replaced with tracers.
self.assertIsInstance(multiplier, Multiplier)

return x * multiplier.x + y * multiplier.y

with fake.fake_pmap_and_jit():
num_devices = jax.device_count()

# pmap over all available devices
transformed_foo = jax.pmap(
foo,
axis_size=num_devices,
static_broadcasted_argnums=static_argnums,
)
x, y = jax.random.randint(
jax.random.PRNGKey(27), (2, num_devices, 3, 5), 0, 10
)

# Test 1.
mult = Multiplier(x=2, y=7)
asserts.assert_trees_all_equal(
transformed_foo(x, mult, y),
foo(x, mult, y),
x * mult.x + y * mult.y,
)

# Test 2.
mult = Multiplier(x=72, y=21)
asserts.assert_trees_all_equal(
transformed_foo(x, mult, y),
foo(x, mult, y),
x * mult.x + y * mult.y,
)

@parameterized.named_parameters([
('fake_nothing', False, False),
('fake_pmap', True, False),
Expand Down

0 comments on commit 42a39e2

Please sign in to comment.