
[linen] fix DenseGeneral init
cgarciae committed Apr 9, 2024
Parent: 2718455 · Commit: 74e328e
Showing 3 changed files with 6 additions and 23 deletions.
12 changes: 2 additions & 10 deletions flax/experimental/nnx/nnx/nn/linear.py
@@ -30,7 +30,6 @@
 import typing as tp
 from types import MappingProxyType
 
-import jax
 import jax.numpy as jnp
 import numpy as np
 from jax import lax
@@ -190,13 +189,7 @@ def __init__(
     n_out_features = len(self.out_features)
 
     def kernel_init_wrap(rng, shape, dtype):
-      flat_shape = (
-        np.prod(shape[:n_batch_axis])
-        * np.prod(shape[n_batch_axis : n_in_features + n_batch_axis]),
-        np.prod(shape[-n_out_features:]),
-      )
-      flat_shape = jax.tree_util.tree_map(int, flat_shape)
-      kernel = self.kernel_init(rng, flat_shape, dtype)
+      kernel = self.kernel_init(rng, shape, dtype)
       if isinstance(kernel, variables.VariableMetadata):
         kernel.raw_value = jnp.reshape(kernel.raw_value, shape)
       else:
@@ -217,8 +210,7 @@ def kernel_init_wrap(rng, shape, dtype):
     if self.use_bias:
 
       def bias_init_wrap(rng, shape, dtype):
-        flat_shape = (int(np.prod(shape)),)
-        bias = self.bias_init(rng, flat_shape, dtype)
+        bias = self.bias_init(rng, shape, dtype)
         if isinstance(bias, variables.VariableMetadata):
           bias.raw_value = jnp.reshape(bias.raw_value, shape)
         else:
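
For context, here is a minimal standalone sketch (not part of this commit) of the flattened shapes that kernel_init_wrap and bias_init_wrap used to pass to the user-supplied initializers; the axis sizes are illustrative. Note that before and after the change the initializer's output is still reshaped to the full parameter shape, so only the shape seen by the initializer differs.

import numpy as np

# Illustrative kernel shape: one batch axis (8), two input axes (4, 5),
# and two output feature axes (6, 7).
shape = (8, 4, 5, 6, 7)
n_batch_axis, n_in_features, n_out_features = 1, 2, 2

# What the removed code in kernel_init_wrap passed to kernel_init:
# a flattened 2D shape.
flat_kernel_shape = (
    int(np.prod(shape[:n_batch_axis])
        * np.prod(shape[n_batch_axis : n_in_features + n_batch_axis])),
    int(np.prod(shape[-n_out_features:])),
)
print(flat_kernel_shape)  # (160, 42)

# What the removed code in bias_init_wrap passed to bias_init, for an
# illustrative bias shape of (8, 6, 7): everything collapsed into one axis.
flat_bias_shape = (int(np.prod((8, 6, 7))),)
print(flat_bias_shape)  # (336,)

# After this commit, kernel_init and bias_init receive the full,
# unflattened shapes instead.
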
15 changes: 2 additions & 13 deletions flax/linen/linear.py
@@ -24,7 +24,6 @@
   Union,
 )
 
-import jax
 import jax.numpy as jnp
 import numpy as np
 from jax import eval_shape, lax
@@ -142,13 +141,7 @@ def __call__(self, inputs: Array) -> Array:
     n_axis, n_features = len(axis), len(features)
 
     def kernel_init_wrap(rng, shape, dtype=jnp.float32):
-      flat_shape = (
-        np.prod(shape[:n_batch_dims])
-        * np.prod(shape[n_batch_dims : n_axis + n_batch_dims]),
-        np.prod(shape[-n_features:]),
-      )
-      flat_shape = jax.tree_util.tree_map(int, flat_shape)
-      kernel = self.kernel_init(rng, flat_shape, dtype)
+      kernel = self.kernel_init(rng, shape, dtype)
       if isinstance(kernel, meta.AxisMetadata):
         return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
       return jnp.reshape(kernel, shape)
@@ -171,11 +164,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
     if self.use_bias:
 
       def bias_init_wrap(rng, shape, dtype=jnp.float32):
-        flat_shape = (
-          np.prod(shape[:n_batch_dims]) * np.prod(shape[-n_features:]),
-        )
-        flat_shape = jax.tree_util.tree_map(int, flat_shape)
-        bias = self.bias_init(rng, flat_shape, dtype)
+        bias = self.bias_init(rng, shape, dtype)
         if isinstance(bias, meta.AxisMetadata):
           return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape))
         return jnp.reshape(bias, shape)
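
As a hedged usage sketch (assumed, not taken from this commit), this is how the change is visible from a custom kernel_init passed to flax.linen.DenseGeneral; the layer sizes and the logging_init helper are illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn

def logging_init(key, shape, dtype=jnp.float32):
  # With this fix the initializer sees the full kernel shape, e.g. (3, 4, 5)
  # below, rather than a flattened 2D shape such as (12, 5).
  print('kernel_init called with shape', shape)
  return nn.initializers.lecun_normal()(key, shape, dtype)

layer = nn.DenseGeneral(features=5, axis=(-2, -1), kernel_init=logging_init)
x = jnp.ones((2, 3, 4))  # batch of 2, input axes (3, 4)
params = layer.init(jax.random.PRNGKey(0), x)

Depending on how an initializer derives its fan axes, the values it produces may differ between the flattened and unflattened shapes.
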
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -146,6 +146,8 @@ filterwarnings = [
   "ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
   # DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
   "ignore:.*jax.tree_map is deprecated.*:DeprecationWarning",
+  # DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
+  "ignore:.*elementwise comparison failed; this will raise an error in the future.*:DeprecationWarning",
 ]
 
 [tool.coverage.report]
