Skip to content

Commit

Permalink
Merge pull request #3790 from google:nnx-graph-node-base
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619594062
  • Loading branch information
Flax Authors committed Mar 27, 2024
2 parents 0ab0365 + 6948138 commit 6f1f1ef
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 261 deletions.
4 changes: 3 additions & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
from .nnx.module import GraphDef as GraphDef
from .nnx.module import M as M
from .nnx.module import Module as Module
from .nnx.module import merge as merge
from .nnx.graph_utils import merge as merge
from .nnx.graph_utils import split as split
from .nnx.graph_utils import update as update
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_block(keys):

# call vmap over create_block, passing the split `params` key
# and immediately merge to get a Block instance
self.layers = nnx.merge(jax.vmap(create_block)(keys))
self.layers = nnx.merge(*jax.vmap(create_block)(keys))

def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array:
# fork Rngs, split keys into `n_layers`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def __init__(self, cfg: Config, *, rngs: nnx.Rngs):

if cfg.scanned:
self.layers = nnx.merge(
jax.vmap(lambda key: DecoderBlock(cfg, rngs=nnx.Rngs(key)).split())(
*jax.vmap(lambda key: DecoderBlock(cfg, rngs=nnx.Rngs(key)).split())(
jax.random.split(rngs.params(), cfg.layers)
)
)
Expand Down
Loading

0 comments on commit 6f1f1ef

Please sign in to comment.