Skip to content

Commit

Permalink
[nnx] use explicit Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 2, 2024
1 parent 0c006b3 commit 29aad7d
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions flax/experimental/nnx/nnx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,14 +844,12 @@ def scan_apply(

# transpose axes state
scan_states = tuple(
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state)
jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
)
# transpose axes arg
scan_args = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
scan_args = jax.tree_map(
lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node)
if axis is not None
else None,
options.in_args_axes,
Expand All @@ -864,17 +862,15 @@ def scan_apply(
args,
is_leaf=lambda x: x is None,
)
scan_kwargs = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
scan_kwargs = jax.tree_map(
lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node)
if axis is not None
else None,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
)
broadcast_kwargs = jax.tree_util.tree_map(
broadcast_kwargs = jax.tree_map(
lambda axis, node: None if axis is not None else node,
options.in_kwargs_axes,
kwargs,
Expand Down Expand Up @@ -933,14 +929,14 @@ def scan_fn(
split_keys, scan_states, scan_args, scan_kwargs = scan

# merge args and kwargs
args = jax.tree_util.tree_map(
args = jax.tree_map(
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_args_axes,
scan_args,
broadcast_args,
is_leaf=lambda x: x is None,
)
kwargs = jax.tree_util.tree_map(
kwargs = jax.tree_map(
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_kwargs_axes,
scan_kwargs,
Expand Down Expand Up @@ -1014,14 +1010,12 @@ def scan_fn(

# transpose axes state
scan_states = tuple(
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state)
jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
)
# transpose axes arg
scan_out = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, 0, axis), node
),
scan_out = jax.tree_map(
lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), node),
options.out_axes,
scan_out,
)
Expand Down Expand Up @@ -1479,16 +1473,16 @@ def vmap_apply(

# infer length
axis_sizes: tp.Set[int] = set()
args_sizes = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
args_sizes = jax.tree_map(
lambda axis, node: jax.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
)
kwargs_sizes = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
kwargs_sizes = jax.tree_map(
lambda axis, node: jax.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_kwargs_axes,
Expand Down

0 comments on commit 29aad7d

Please sign in to comment.