Initialization of Submodules Lifted with flax.nn.scan
#2754
Comments
Hey @daskol, sorry this took a while. I took your code and created a minimal working version:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class MLP(nn.Module):
    @nn.compact
    def __call__(self, xs, _):
        h = nn.Dense(features=2)(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2)(h)
        return xs + h, None


class Transformer(nn.Module):
    n_layers: int = 4

    @nn.compact
    def __call__(self, x):
        ScanMLP = nn.scan(
            target=MLP, variable_axes={'params': 0}, variable_broadcast=False,
            split_rngs={'params': True}, length=self.n_layers)
        x, _ = ScanMLP()(x, None)
        return x


model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))
print(y.shape)
print(jax.tree_map(jnp.shape, variables))
```

What I think we can do is improve […]
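A minimal sketch, assuming `model` and `variables` from the snippet above are still in scope: with `variable_axes={'params': 0}`, the parameters of all `n_layers` copies of `MLP` are stacked along axis 0, so every parameter leaf gains a leading dimension of size 4.

```python
# Sketch continuing the snippet above: all parameters live inside the scanned
# MLP, so each leaf is stacked along axis 0 with length n_layers, e.g. a
# Dense kernel has shape (4, 2, 2) instead of (2, 2).
for leaf in jax.tree_util.tree_leaves(variables):
    assert leaf.shape[0] == model.n_layers
```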
@cgarciae Thank you for your time. Eventually, I have managed to scan over modules with […]
Hey @daskol, I am not sure I understand the need to use […]. Here is the same model written with `setup`:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class MLP(nn.Module):
    def setup(self):
        self.dense1 = nn.Dense(features=2)
        self.dense2 = nn.Dense(features=2)

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None


class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        ScanMLP = nn.scan(
            target=MLP, variable_axes={'params': 0}, variable_broadcast=False,
            split_rngs={'params': True}, length=self.n_layers)
        self.scan_mlp = ScanMLP()

    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x


model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))
print(y.shape)
print(jax.tree_map(jnp.shape, variables))
```
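A small usage sketch, continuing the snippet above: the variables returned by `init_with_output` can be reused with `Module.apply`, and no RNGs are needed at apply time since the model has no stochastic layers such as dropout.

```python
# Sketch continuing the snippet above: reuse the initialized variables.
out = model.apply(variables, jnp.ones((1, 2)))
print(out.shape)  # (1, 2)
```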
@cgarciae There are actually multiple issues with […]. However, the original problem which I reported is that I need to construct the scanned `MLP` with arguments that depend on the outer module (e.g. a kernel initializer scaled by `1 / n_layers`). So, replacing the `MLP` class passed as `target` with a closure that builds `MLP(1 / self.n_layers)` does not work:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import orthogonal


class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None


class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        # 1. This works.
        fn = MLP

        # 2. This does not work.
        def fn():
            return MLP(1 / self.n_layers)

        ScanMLP = nn.scan(target=fn, variable_axes={'params': 0},
                          variable_broadcast=False, split_rngs={'params': True},
                          length=self.n_layers)
        self.scan_mlp = ScanMLP()

    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x


model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))
print(y.shape)
print(jax.tree_map(jnp.shape, variables))
```

UPD: The most frequent exception thrown says that […]
@daskol There are 2 ways to do this. The class version:

```python
class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None


class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        ScanMLP = nn.scan(target=MLP, variable_axes={'params': 0},
                          variable_broadcast=False, split_rngs={'params': True},
                          length=self.n_layers)
        self.scan_mlp = ScanMLP(1 / self.n_layers)

    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x
```

And the functional version:

```python
class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h


class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        def scan_fn(mlp, x, _):
            return mlp(x), None

        self.scan = nn.scan(target=scan_fn, variable_axes={'params': 0},
                            variable_broadcast=False, split_rngs={'params': True},
                            length=self.n_layers)
        self.mlp = MLP(1 / self.n_layers)

    def __call__(self, x):
        x, _ = self.scan(self.mlp, x, None)
        return x
```

The general signature for either the Module's `__call__` or `scan_fn` is `(carry, *xs) -> (carry, ys)`, with the module instance prepended as the first argument in the functional case, where in this case the carry is `x` and there are no scanned inputs or outputs, so the second element is just `None`.
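A sketch of what the `ys` half of that signature can carry, built on the functional version above (the per-layer output and printed shapes are an illustrative assumption, not part of the original example): returning the activation as the second element of `scan_fn` makes `nn.scan` stack it along axis 0, one slice per scanned layer.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import orthogonal


class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h


class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        def scan_fn(mlp, carry, _):
            carry = mlp(carry)
            return carry, carry  # emit the activation as a per-step output too

        self.scan = nn.scan(target=scan_fn, variable_axes={'params': 0},
                            variable_broadcast=False, split_rngs={'params': True},
                            length=self.n_layers)
        self.mlp = MLP(1 / self.n_layers)

    def __call__(self, x):
        return self.scan(self.mlp, x, None)


model = Transformer(n_layers=4)
(y, per_layer), variables = model.init_with_output(
    jax.random.PRNGKey(42), jnp.ones((1, 2)))
print(y.shape)          # (1, 2): the final carry
print(per_layer.shape)  # (4, 1, 2): the stacked per-layer activations
```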
@cgarciae Thank you very much! Now it works perfectly.

I checked the ReadTheDocs. It says that the loop body should have the signature […]. What is the difference? Does option […]?

I found a small typo in #2839 and left a comment. Could I ask you in advance to enable inter-sphinx links to […]? It is unclear to me where […].
One more issue 😄. Promise this is the last one. There are a lot of questions about `flax.nn.scan`, and RTD and existing GitHub issues do not solve them. With a very deep model, compilation times become insane and it takes about 1 hour to compile the model for the Nvidia runtime. So, I decided to prevent loop unrolling with `jax.lax.scan` and its lifted counterpart `flax.nn.scan` (a minimal `lax.scan` sketch follows below). However, I faced multiple issues. An incomplete list of issues follows.

- Passing `args` and `kwargs` to the `__call__` of a submodule (in this case `MLP`).
- Using `flax.nn.scan` as RTD says.
- `flax.nn.scan` always returns `(carry, args)` even if there is only `carry` and no `args`.
- `target` should be either a type of `nn.Module` or a function which accepts an `nn.Module` (type?) as its first positional argument.
- If one sets the `name` of modules in `MLP`, then an exception is thrown. It is a bit strange because all parameter trees are merged into a single parameter tree.

In these experiments flax v0.6.3 and jax v0.4.1 are used.
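For reference, the unrolling point above can be seen with plain `jax.lax.scan` (the toy layer, weights, and shapes below are illustrative assumptions): the scan body is traced and compiled once and then iterated, whereas a Python `for` loop over layers inlines one copy of the layer per iteration into the compiled program, which is what blows up compilation time for deep models.

```python
import jax
import jax.numpy as jnp


def layer(carry, w):
    # One toy residual step: `carry` is the activation, `w` a per-layer weight.
    return carry + jnp.tanh(carry @ w), None  # (new_carry, no per-step output)


n_layers = 4
ws = jnp.stack([jnp.eye(2)] * n_layers)  # stacked per-layer weights: (4, 2, 2)
x0 = jnp.ones((1, 2))

# `layer` is compiled once and iterated n_layers times instead of being
# unrolled n_layers times into the program.
y, _ = jax.lax.scan(layer, x0, ws)
print(y.shape)  # (1, 2)
```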