Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] improve nnx.scan in_axes/out_axes #4157

Merged
merged 1 commit into from
Sep 3, 2024
Merged

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Aug 29, 2024

What does this PR do?

Makes nnx.scan's in_axes and out_axes API is more flexible in the following ways:

  • The function is no longer limited to returning carry, outputs, out_axes can use Carry to control its structure, by default its now (Carry, 0) to match current behavior.
  • Carry is no longer required to be present in in_axes/out_axes, however if Carry is present in in_axes it must also be present in out_axes and vice versa.
  • in_axes is not limited to a Sequence, as with vmap a single top-level integer or None can now be used to configure all arguments. A top-level Carry is also accepted by in_axes but its interpreted as (Carry,), it means that the function accepts a single argument that is the Carry.

Example: Scan over layers

state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None})

class MLP(nnx.Module):
  @nnx.split_rngs(splits=5)
  @nnx.vmap(in_axes=(state_axes, 0))
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(3, 3, rngs=rngs)

  @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry)
  def __call__(self, x: jax.Array):
    return nnx.gelu(self.linear(x))

module = MLP(nnx.Rngs(0))

x = jnp.ones((1, 3))
y = module(x)

assert y.shape == (1, 3)

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@copybara-service copybara-service bot merged commit 35c4edf into main Sep 3, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-improve-scan branch September 3, 2024 20:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants