Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call user-defined variable transforms before determining axis size in nn.vmap. #4026

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,17 +790,28 @@ def vmap(
rng_axes = tuple(0 if rng_split else None for rng_split in rng_splits)

def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
# optional user-defined variable transform on the way in
new_variable_groups = []
for var_group, axis in zip(variable_groups, variable_in_axes):
if axis is not None:
new_variable_groups.append(
meta.remove_axis(var_group, axis, metadata_params)
)
else:
new_variable_groups.append(var_group)
variable_groups = tuple(new_variable_groups)

# split rngs
def find_axis_size(axis, x):
if axis is not None:
leaves = jax.tree_util.tree_leaves(x)
if leaves:
return leaves[0].shape[axis]
return ()

# split rngs
axis_sizes = jax.tree_util.tree_map(
find_axis_size, (variable_in_axes, in_axes), (variable_groups, args),
is_leaf=lambda x: x is None
find_axis_size, (variable_in_axes, in_axes), (variable_groups, args),
is_leaf=lambda x: x is None
)
axis_sizes = set(jax.tree_util.tree_leaves(axis_sizes))
if axis_size is None and len(axis_sizes) == 1:
Expand All @@ -823,16 +834,6 @@ def find_axis_size(axis, x):
for rng_group, split in zip(rng_groups, rng_splits)
)

new_variable_groups = []
for var_group, axis in zip(variable_groups, variable_in_axes):
if axis is not None:
new_variable_groups.append(
meta.remove_axis(var_group, axis, metadata_params)
)
else:
new_variable_groups.append(var_group)
variable_groups = tuple(new_variable_groups)

@functools.partial(
jax.vmap,
in_axes=(variable_in_axes, rng_axes, in_axes),
Expand All @@ -847,6 +848,7 @@ def mapped(variable_groups, rng_groups, args):
y = fn(scope, *args)
return y, repack_fn(scope)

# optional user-defined variable transform on the way out
y, vars_out = mapped(variable_groups, rng_groups, args)
new_vars_out = []
for var_group, axis in zip(vars_out, variable_out_axes):
Expand Down
47 changes: 45 additions & 2 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

from functools import partial
import operator
from typing import Any, Callable, Sequence
from typing import Any, Callable, Dict, Sequence
import unittest

from absl.testing import absltest, parameterized
from flax import errors
from flax import linen as nn
from flax import serialization
from flax.core import copy, freeze
from flax import struct
from flax.core import copy, freeze, AxisMetadata
from flax.linen.transforms import _HashableProxy
import jax
from jax import random
Expand Down Expand Up @@ -2533,6 +2534,48 @@ def comparison_fn(x, y):
self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))

def test_vmap_add_remove_axis_transforms(self):
class BoxedData(struct.PyTreeNode, AxisMetadata):
value: Any
def unbox(self):
return self.value
def replace_boxed(self, val):
return self.replace(value=val)
def add_axis(self, index: int, params: Dict[Any, Any]):
value = jnp.mean(self.value, axis=index)
return self.replace(value=value)
def remove_axis(self, index: int, params: Dict[Any, Any]):
value_shape = list(self.value.shape)
value_shape.insert(index, params['axis_size'])
value = jnp.broadcast_to(self.value, value_shape)
return self.replace(value=value)

class Top(nn.Module):
@nn.compact
def __call__(self, x):
VFoo = nn.vmap(
Foo, in_axes=0, out_axes=0, variable_axes={'params':0, 'aux': 0},
metadata_params={'axis_size': x.shape[0]},
)
vfoo = VFoo(name="vfoo")
y = vfoo(x)
y = vfoo(x)
assert vfoo.variables['aux']['v'].value.shape == ()
return y

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
if self.has_variable('aux', 'v'):
assert self.variables['aux']['v'].value.shape == ()
boxed_v = self.variable('aux', 'v', lambda: BoxedData(jnp.ones(())))
assert self.variables['aux']['v'].value.shape == ()
return x

vs = Top().init(random.key(0), jnp.ones((2,5)))
y = Top().apply(vs, jnp.ones((2, 5)))
assert vs['aux']['vfoo']['v'].value.shape == ()


if __name__ == '__main__':
absltest.main()
Loading