
Initialization of Submodules Lifted with flax.nn.scan #2754

Closed
daskol opened this issue Dec 22, 2022 · 6 comments · Fixed by #2839

daskol commented Dec 22, 2022

One more issue 😄. I promise this is the last one. There are a lot of questions about flax.nn.scan, and neither the RTD docs nor the existing GitHub issues answer them.

With a very deep model, compilation times become insane: it takes about 1 hour to compile the model for the Nvidia runtime. So I decided to prevent loop unrolling with jax.lax.scan and its lifted counterpart flax.nn.scan (a plain jax.lax.scan sketch of what I mean is at the end of this issue). However, I ran into multiple issues. An incomplete list follows.

  1. There is no clear way to initialize scanned submodules. I came up with a workaround: pass everything as args and kwargs to the submodule's __call__ (in this case MLP).
  2. flax.nn.scan does not accept the keyword arguments that the RTD docs describe.
  3. flax.nn.scan always returns a (carry, ys) pair, even if there is only a carry and no scanned outputs.
  4. RTD says that target should be either an nn.Module subclass or a function which accepts an nn.Module (the type?) as its first positional argument.
  5. If one specifies the names of the submodules in MLP, an exception is thrown. This is a bit strange because all parameter trees are merged into a single parameter tree.
import flax.linen as nn
import jax
import jax.numpy as jnp


def initializer(val):
    def init(key, shape, dtype):
        return jnp.full(shape, val, dtype)

    return init


class MLP(nn.Module):

    @nn.compact
    def __call__(self, xs, var):
        h = nn.Dense(features=2, kernel_init=initializer(var))(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2, kernel_init=initializer(var))(h)
        return xs + h, None


class Transformer(nn.Module):

    length: int = 3

    def setup(self):
        def fn(self, *args, **kwargs):
            return MLP(self, *args, **kwargs)

        # FAIL: Function instead of derived type from nn.Module does not work.
        #
        #   ScanMLP = nn.scan(target=fn, ...)
        #
        #   jax._src.traceback_util.UnfilteredStackTrace: TypeError:
        #   Transformer.setup.<locals>.fn() missing 1 required positional
        #   argument: 'self'

        # OK: No problems.
        ScanMLP = nn.scan(target=fn,
                          variable_axes={'params': 0},
                          variable_broadcast=False,
                          split_rngs={'params': True},
                          length=self.length)

        self.vars = jnp.arange(self.length)  # e.g. [0, 1, 2]
        self.mlp = ScanMLP()  # FAIL: ScanMLP(self.vars)

    @nn.compact  # OK: This decorator does nothing. Why?
    def __call__(self, xs):
        carry, out = self.mlp(xs, self.vars)  # OK: Axis 0 (implicitly).
        assert out is None
        return carry


model = Transformer(length=1250)
ys, state = jax.jit(model.init_with_output)(jax.random.PRNGKey(42),
                                            jnp.ones((3, 2)))
kernel = state['params']['mlp']['Dense_0']['kernel']
assert (kernel[0, ...] == jnp.zeros((2, 2))).all()
assert (kernel[1, ...] == jnp.ones((2, 2))).all()

In these experiments flax v0.6.3 and jax v0.4.1 are used.
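
For reference, the plain jax.lax.scan pattern I want to reproduce with the lifted transform looks roughly like this (a minimal sketch with made-up shapes, not the actual model):

import jax
import jax.numpy as jnp

n_layers, width = 1250, 2

# One weight matrix per layer, stacked along axis 0.
keys = jax.random.split(jax.random.PRNGKey(0), n_layers)
weights = jax.vmap(lambda key: jax.random.normal(key, (width, width)))(keys)

def layer(h, w):
    # One scan step: (carry, x) -> (carry, y); no per-step output here.
    return jnp.tanh(h @ w), None

h0 = jnp.ones((width,))
h_final, _ = jax.lax.scan(layer, h0, weights)  # loops without unrolling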

@zaxtax zaxtax self-assigned this Dec 22, 2022
@zaxtax zaxtax assigned cgarciae and unassigned zaxtax Jan 6, 2023

cgarciae commented Feb 2, 2023

Hey @daskol, sorry this took a while. I took your code and created a minimal working version:

import flax.linen as nn
import jax
import jax.numpy as jnp

class MLP(nn.Module):
    @nn.compact
    def __call__(self, xs, _):
        h = nn.Dense(features=2)(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2)(h)
        return xs + h, None

class Transformer(nn.Module):
    n_layers: int = 4

    @nn.compact
    def __call__(self, x):
        ScanMLP = nn.scan(
            target=MLP, variable_axes={'params': 0}, variable_broadcast=False,
            split_rngs={'params': True}, length=self.n_layers)
        x, _ = ScanMLP()(x, None)
        return x

model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))

print(y.shape)
print(jax.tree_map(jnp.shape, variables))

What I think we can do is improve nn.scan's documentation to show how to do this correctly; I am inclined to add a modified version of this example.


daskol commented Feb 2, 2023

@cgarciae Thank you for your time. Eventually, I managed to scan over modules with @nn.compact, but another issue appeared (see #2750 and the underlying jax-ml/jax#13762). Briefly, the issue is that model initialization requires estimating some parameters (the bias and variance of the initializer) with non-jax code, which can be done with the @nn.compact decorator. Is it possible to rewrite your example to use setup instead of @nn.compact?


cgarciae commented Feb 2, 2023

Hey @daskol, I am not entirely sure I understand the need to use setup (I did read the issue about quad), but in any case here is a setup version:

import flax.linen as nn
import jax
import jax.numpy as jnp

class MLP(nn.Module):
    def setup(self):
        self.dense1 = nn.Dense(features=2)
        self.dense2 = nn.Dense(features=2)

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None

class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        ScanMLP = nn.scan(
            target=MLP, variable_axes={'params': 0}, variable_broadcast=False,
            split_rngs={'params': True}, length=self.n_layers)
        self.scan_mlp = ScanMLP()
    
    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x

model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))

print(y.shape)
print(jax.tree_map(jnp.shape, variables))


daskol commented Feb 2, 2023

@cgarciae There are actually multiple issues with flax.nn.scan. I tried to solve my problem based on my understanding of how flax.nn.scan works, but ran into some unexpected behaviours and reported them here (see the FAIL and OK comments in the issue description).

However, the original problem I reported is that I need to

  1. apply flax.nn.scan to construct thousands of layers with special initialization parameters (variance);
  2. and these parameters, in general, depend on the layer depth.

So, replacing target=MLP with target=lambda: MLP(variance) in your last snippet does not work: the code throws an exception. The full example based on your code follows.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import orthogonal


class MLP(nn.Module):

    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None

class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        # 1. This works.
        fn = MLP
        # 2. This does not work.
        def fn():
            return MLP(1 / self.n_layers)
        ScanMLP = nn.scan(target=fn, variable_axes={'params': 0},
            variable_broadcast=False, split_rngs={'params': True},
            length=self.n_layers)
        self.scan_mlp = ScanMLP()

    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x

model = Transformer(n_layers=4)
y, variables = model.init_with_output(jax.random.PRNGKey(42), jnp.ones((1, 2)))

print(y.shape)
print(jax.tree_map(jnp.shape, variables))

UPD: The most frequently thrown exception says that self is missing. I completely do not understand how that can be, even though I dug into the flax sources and read a ton of code about how scopes work and how they are associated with modules and functional routines. It is still unclear to me why self is not in the context captured by the inner function fn.


cgarciae commented Feb 3, 2023

@daskol There are two ways to do this. The class version:

class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs, _):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h, None

class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        ScanMLP = nn.scan(target=MLP, variable_axes={'params': 0},
            variable_broadcast=False, split_rngs={'params': True},
            length=self.n_layers)
        self.scan_mlp = ScanMLP(1 / self.n_layers)

    def __call__(self, x):
        x, _ = self.scan_mlp(x, None)
        return x

And the functional version:

class MLP(nn.Module):
    var: float = 1.0

    def setup(self):
        self.dense1 = nn.Dense(features=2, kernel_init=orthogonal(self.var))
        self.dense2 = nn.Dense(features=2, kernel_init=orthogonal(self.var))

    def __call__(self, xs):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h

class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        def scan_fn(mlp, x, _):
            return mlp(x), None
        self.scan = nn.scan(target=scan_fn, variable_axes={'params': 0},
            variable_broadcast=False, split_rngs={'params': True},
            length=self.n_layers)
        self.mlp = MLP(1 / self.n_layers)

    def __call__(self, x):
        x, _ = self.scan(self.mlp, x, None)
        return x

The general signature for either the Module's __call__ method or the scan_fn function is

(module, carry, xs) -> (carry, ys)

where in this case xs = ys = None. I will try to improve the example for the functional version in the docs for nn.scan.
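
To make the xs/ys part of that signature concrete, here is the same functional version with a real scanned input instead of None: each step receives one slice of scales along axis 0 and also emits a per-step output. (A minimal sketch; the scales argument and the scaling itself are only for illustration.)

import flax.linen as nn
import jax
import jax.numpy as jnp

class MLP(nn.Module):
    def setup(self):
        self.dense1 = nn.Dense(features=2)
        self.dense2 = nn.Dense(features=2)

    def __call__(self, xs):
        h = self.dense1(xs)
        h = nn.relu(h)
        h = self.dense2(h)
        return xs + h

class Transformer(nn.Module):
    n_layers: int = 4

    def setup(self):
        def scan_fn(mlp, carry, scale):
            # (module, carry, xs) -> (carry, ys): scale is one slice of the
            # scanned input; we also return it as the per-step output.
            return mlp(carry) * scale, scale
        self.scan = nn.scan(target=scan_fn, variable_axes={'params': 0},
            variable_broadcast=False, split_rngs={'params': True},
            length=self.n_layers)
        self.mlp = MLP()

    def __call__(self, x, scales):
        return self.scan(self.mlp, x, scales)

model = Transformer(n_layers=4)
scales = jnp.linspace(0.5, 1.0, model.n_layers)  # one value per layer, scanned along axis 0
(y, ys), variables = model.init_with_output(
    jax.random.PRNGKey(42), jnp.ones((1, 2)), scales)
print(y.shape, ys.shape)  # (1, 2) (4,)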


daskol commented Feb 3, 2023

@cgarciae Thank you very much! Now, it works perfectly.

The general signature for either the Module's ...

I checked the Read the Docs page. It says that the loop body should have the signature

(scope, body, carry, *xs) -> (carry, ys)

What is the difference? Does the variable_broadcast option of flax.nn.scan affect the signature (assuming scope is the same as body)?

I found a small typo in #2839 and left a comment. Could I also ask you to enable intersphinx links to jax and to clarify which vmap or scan the documentation refers to? For example, the current docs of flax.linen.scan say

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

It is unclear to me which vmap this refers to: jax.vmap or flax.linen.vmap?
