Scope.path names vs variables structure inconsistencies #2354

Closed
cgarciae opened this issue Aug 1, 2022 · 1 comment
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@cgarciae
Collaborator

cgarciae commented Aug 1, 2022

While developing #2316, Scope.path values were captured via a new mechanism built for tabulate. That mechanism assumes a 1:1 correspondence between path names and the variable structure, but a couple of workarounds were needed because this assumption breaks under certain circumstances. Here are some minimal reproducible examples showing the unexpected behavior.

1. Lifted Modules

When using lifted Modules, Scope.path names use the following notation:

<transformation_name>(<module_name>)

So you get names like scan(ScanLSTMCell_0), whereas only ScanLSTMCell_0 appears in the variable structure. The code below shows the difference between the path names and the corresponding variable structure for an nn.scan + LSTMCell example:

Code
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import random

PATHS = []

class LSTMCell(nn.LSTMCell):
    def __call__(self, carry, inputs):
        PATHS.append(self.scope.path)
        return super().__call__(carry, inputs)

class LSTM(nn.Module):
    out_feat: int

    @nn.compact
    def __call__(self, x):
        PATHS.append(self.scope.path)
        carry = nn.LSTMCell.initialize_carry(
            random.PRNGKey(0), x.shape[:1], self.out_feat
        )
        Cell = nn.scan(
            LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            variable_axes={"intermediates": 1},
            in_axes=1,
            out_axes=1,
        )
        return Cell()(carry, x)

lstm = LSTM(out_feat=128)
variables = lstm.init(random.PRNGKey(0), jnp.ones((32, 128, 64)))

print(PATHS, "\n")
print(jax.tree_util.tree_map(lambda x: x.shape, variables))
[(), ('scan(ScanLSTMCell_0)',), ('scan(ScanLSTMCell_0)',)]

FrozenDict({
    params: {
        ScanLSTMCell_0: {
            hf: {...},
            hg: {...},
            hi: {...},
            ho: {...},
            if: {...},
            ig: {...},
            ii: {...},
            io: {...},
        },
    },
})
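
One way to reconcile the two naming schemes is to strip the <transformation_name>(<module_name>) wrapper from each path entry before comparing it with the variable structure. The sketch below is only an illustration: `strip_lift_prefix` and `normalize_path` are hypothetical helpers, not part of Flax, and the regular expression assumes the wrapper always has the form shown above.

```python
import re

# Matches "<transformation_name>(<module_name>)", e.g. "scan(ScanLSTMCell_0)".
# Assumption: transformation names are plain identifiers.
_LIFTED = re.compile(r"^[A-Za-z_]\w*\((.*)\)$")


def strip_lift_prefix(name: str) -> str:
    """Remove any (possibly nested) transformation wrappers from one path entry."""
    while True:
        match = _LIFTED.match(name)
        if match is None:
            return name
        name = match.group(1)


def normalize_path(path):
    """Hypothetical helper: map a Scope.path tuple to variable-structure names."""
    return tuple(strip_lift_prefix(part) for part in path)


print(normalize_path(("scan(ScanLSTMCell_0)",)))  # ('ScanLSTMCell_0',)
```

Entries without a wrapper pass through unchanged, so the same helper can be applied to every recorded path.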

2. Module reuse

When reusing a module (calling the same module more than once), the path can be wrong after the first call.

Using Setup

The code below creates a CNN module that has a ConvBlock submodule, which in turn calls Conv. Here the submodules are created during setup, and CNN calls self.block twice.

Code
import jax
import jax.numpy as jnp
from flax import linen as nn

PATHS = []

class Conv(nn.Conv):
    def __call__(self, *args, **kwargs):
        PATHS.append(("inside conv", self.scope.path))
        return super().__call__(*args, **kwargs)

class ConvBlock(nn.Module):
  def setup(self) -> None:
    self.conv = Conv(32, [3, 3])

  def __call__(self, x):
    PATHS.append(("ConvBlock start", self.scope.path))
    x = self.conv(x)
    PATHS.append(("ConvBlock end", self.scope.path))
    return x

class CNN(nn.Module):
  def setup(self):
    self.block = ConvBlock()
  def __call__(self, x):
    x = self.block(x)
    x = self.block(x)
    return x

x = jnp.ones((4, 28, 28, 32))
variables = CNN().init(jax.random.PRNGKey(0), x)

for p in PATHS:
  print(p)
('ConvBlock start', ('block',))     # correct
('inside conv', ('block', 'conv'))  # correct
('ConvBlock end', ('block',))       # correct
('ConvBlock start', ('block',))     # correct
('inside conv', ())                 # wrong
('ConvBlock end', ('block',))       # correct

Notice that the 'inside conv' path is empty the second time self.block is called.

Using nn.compact

If you use nn.compact instead of setup, the path names are wrong in a different way. Here block is instantiated at the beginning of CNN.__call__ and used twice.

Code
import jax
import jax.numpy as jnp
from flax import linen as nn

PATHS = []

class Conv(nn.Conv):
    def __call__(self, *args, **kwargs):
        PATHS.append(("inside conv", self.scope.path))
        return super().__call__(*args, **kwargs)

class ConvBlock(nn.Module):
  @nn.compact
  def __call__(self, x):
    PATHS.append(("ConvBlock start", self.scope.path))
    x = Conv(32, [3, 3])(x)
    PATHS.append(("ConvBlock end", self.scope.path))
    return x

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    block = ConvBlock()
    x = block(x)
    x = block(x)
    return x

x = jnp.ones((4, 28, 28, 32))
variables = CNN().init(jax.random.PRNGKey(0), x)

for p in PATHS:
  print(p)
('ConvBlock start', ('ConvBlock_0',))      # correct
('inside conv', ('ConvBlock_0', 'Conv_0')) # correct
('ConvBlock end', ('ConvBlock_0',))        # correct
('ConvBlock start', ())                    # wrong
('inside conv', ('Conv_0',))               # wrong
('ConvBlock end', ())                      # wrong

Notice that now all paths are wrong the second time around.
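
The mismatches in both reuse examples can be flagged mechanically by grouping the recorded (label, path) pairs by label and reporting any label seen with more than one distinct path. This is a minimal sketch, assuming each label is emitted by a single module instance so its path should be identical on every call; `inconsistent_labels` is a hypothetical helper, and the records are copied from the nn.compact output above.

```python
from collections import defaultdict


def inconsistent_labels(records):
    """Return {label: [distinct paths]} for labels recorded with more than one path."""
    seen = defaultdict(list)
    for label, path in records:
        if path not in seen[label]:
            seen[label].append(path)
    return {label: paths for label, paths in seen.items() if len(paths) > 1}


# Records copied from the nn.compact output above.
records = [
    ("ConvBlock start", ("ConvBlock_0",)),
    ("inside conv", ("ConvBlock_0", "Conv_0")),
    ("ConvBlock end", ("ConvBlock_0",)),
    ("ConvBlock start", ()),
    ("inside conv", ("Conv_0",)),
    ("ConvBlock end", ()),
]

for label, paths in inconsistent_labels(records).items():
    print(label, paths)
```

With these records every label is reported, since each one was seen with two different paths. Passing the PATHS list from either example directly should work the same way, because it already holds (label, path) tuples.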

@cgarciae cgarciae changed the title Scope.path vs variables structure inconsistencies Scope.path names vs variables structure inconsistencies Aug 1, 2022
@cgarciae cgarciae added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label Aug 1, 2022
@cgarciae
Collaborator Author

cgarciae commented Aug 3, 2022

Case 1 is part of the functional core API, so it's not really an error, just a mismatch between the functional API and the Module API.
Case 2 was fixed by #2360.

@cgarciae cgarciae closed this as completed Aug 3, 2022