flax.linen.remat with concrete=True doesn't work with jax 0.3.17 #2452

Closed
lucaslingle opened this issue Sep 10, 2022 · 1 comment · Fixed by #2457

lucaslingle commented Sep 10, 2022

Problem you have encountered:

This may already be on the Flax team's radar, but I noticed that when using flax.linen.remat, setting concrete=True doesn't work with JAX 0.3.17, for the reasons discussed here.

As of Flax 0.6.0, flax.linen.remat:
(1) passes the argument concrete=True through to jax.remat, which now raises an error.
(2) does not accept the static_argnums argument that the latest jax.remat uses instead.

Interestingly, pip's dependency resolver did not seem to be aware of this incompatibility; running pip install jax flax allowed me to install flax==0.6.0 with jax==0.3.17, leading to the observed problem.
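Since the resolver does not catch this, a script can guard against the combination itself. The following is a minimal sketch using only the standard library; the helper names (parse_version, is_reported_bad_pair, installed_version) are hypothetical, and it flags only the one pair this report confirms as broken (flax 0.6.0 with jax 0.3.17). Whether other pairs are affected is not established here.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn a version string such as '0.3.17' into a tuple of up to three ints."""
    return tuple(int(n) for n in re.findall(r"\d+", text)[:3])


def is_reported_bad_pair(jax_version, flax_version):
    """True only for the pair confirmed as broken in this report."""
    return (parse_version(flax_version) == (0, 6, 0)
            and parse_version(jax_version) == (0, 3, 17))


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v = installed_version("jax")
    flax_v = installed_version("flax")
    if jax_v and flax_v and is_reported_bad_pair(jax_v, flax_v):
        raise SystemExit(
            f"flax=={flax_v} with jax=={jax_v}: nn.remat(concrete=True) fails"
        )
```

Running this at the top of a training script fails fast with a readable message instead of the long traceback further down.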

As a workaround, I've downgraded to jax==0.3.16, and am running jax.config.update("jax_new_checkpoint", False) at the top of my scripts, as suggested by the link above.

What you expected to happen:

To ensure compatibility with JAX's remat functionality, future versions of flax.linen.remat would ideally accept a static_argnums argument and pass it through to the jax.remat implementation.
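For reference, this is roughly what static_argnums looks like when used directly with jax.checkpoint (the same function as jax.remat). It is a small sketch in plain JAX, not the Flax API: the function names block and loss and the arithmetic inside them are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Argument 1 (is_training) is marked static, so inside the function it is an
# ordinary Python bool and normal control flow works on it.
@partial(jax.checkpoint, static_argnums=(1,))
def block(x, is_training):
    if is_training:
        return jnp.sin(x) * 2.0
    return jnp.sin(x)


def loss(x, is_training):
    return jnp.sum(block(x, is_training))


x = jnp.arange(3.0)
grad_train = jax.grad(loss)(x, True)   # differentiates the training branch
grad_eval = jax.grad(loss)(x, False)   # differentiates the eval branch
```

Static arguments must be hashable, and each distinct value is traced separately.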

In the traceback triggered by Flax passing concrete=True, the JAX developers also remark that

If jax.numpy operations need to be performed on static arguments, we can use the jax.ensure_compile_time_eval() context manager.

which may also be relevant to the future design of flax.linen.remat.
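A runnable variant of that remark, again in plain JAX rather than Flax, might look like the sketch below. The names block and loss, the tuple-valued static argument scales, and the branch bodies are all hypothetical; the only calls taken from the error message are jax.checkpoint with static_argnums and jax.ensure_compile_time_eval().

```python
from functools import partial

import jax
import jax.numpy as jnp


# scales is static (a hashable tuple). The jnp reduction on it would normally
# be staged out during tracing; ensure_compile_time_eval() runs it eagerly so
# the result can drive a Python `if`.
@partial(jax.checkpoint, static_argnums=(1,))
def block(x, scales):
    with jax.ensure_compile_time_eval():
        total_is_positive = bool(jnp.sum(jnp.asarray(scales)) > 0)
    if total_is_positive:
        return jnp.sin(x)
    return jnp.cos(x)


def loss(x, scales):
    return jnp.sum(block(x, scales))


x = jnp.arange(3.0)
grad_pos = jax.grad(loss)(x, (1.0, 2.0))    # sum > 0, sin branch
grad_neg = jax.grad(loss)(x, (1.0, -3.0))   # sum <= 0, cos branch
```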

Steps to reproduce:

The problem can be reproduced by running the following script:

import flax.linen as nn
import jax

class Foo(nn.Module):
    def setup(self):
        self.linear = nn.remat(nn.Dense, concrete=True)(100, use_bias=False)

    def __call__(self, inputs):
        return self.linear(inputs)

if __name__ == '__main__':
    rng = jax.random.PRNGKey(0)
    rng, sk1, sk2 = jax.random.split(rng, 3)
    foo = Foo()
    input = jax.random.normal(sk1, [1, 10])
    params = foo.init({"params": sk2}, input)["params"]
    out = foo.apply({"params": params}, input)

Logs, error messages, etc:

When I run the above script, I obtain the following traceback:

Traceback (most recent call last):
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 17, in <module>
    params = foo.init({"params": sk2}, input)["params"]
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1273, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1229, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/scope.py", line 897, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/scope.py", line 865, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1647, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 361, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 9, in __call__
    return self.linear(inputs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/transforms.py", line 316, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/lift.py", line 213, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/lift.py", line 1177, in inner
    def rematted(variable_groups, rng_groups, *args, **kwargs):
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/api.py", line 3084, in checkpoint
    raise NotImplementedError(msg)
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: The 'concrete' option to jax.checkpoint / jax.remat is deprecated; in its place, you can use its `static_argnums` option, and if necessary the `jax.ensure_compile_time_eval()` context manager.

For example, if using `concrete=True` for an `is_training` flag:

  from functools import partial

  @partial(jax.checkpoint, concrete=True)
  def foo(x, is_training):
    if is_training:
      return f(x)
    else:
      return g(x)

replace it with a use of `static_argnums`:

  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, is_training):
    ...

If jax.numpy operations need to be performed on static arguments, we can use the `jax.ensure_compile_time_eval()` context manager. For example, we can replace this use of `concrete=True`
:
  @partial(jax.checkpoint, concrete=True)
  def foo(x, y):
    if y > 0:
      return f(x)
    else:
      return g(x)

with this combination of `static_argnums` and `jax.ensure_compile_time_eval()`:

  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, y):
    with jax.ensure_compile_time_eval():
      y_pos = y > 0
    if y_pos:
      return f(x)
    else:
      return g(x)

See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 17, in <module>
    params = foo.init({"params": sk2}, input)["params"]
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 9, in __call__
    return self.linear(inputs)
NotImplementedError: The 'concrete' option to jax.checkpoint / jax.remat is deprecated; in its place, you can use its `static_argnums` option, and if necessary the `jax.ensure_compile_time_eval()` context manager.

For example, if using `concrete=True` for an `is_training` flag:

  from functools import partial

  @partial(jax.checkpoint, concrete=True)
  def foo(x, is_training):
    if is_training:
      return f(x)
    else:
      return g(x)

replace it with a use of `static_argnums`:

  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, is_training):
    ...

If jax.numpy operations need to be performed on static arguments, we can use the `jax.ensure_compile_time_eval()` context manager. For example, we can replace this use of `concrete=True`
:
  @partial(jax.checkpoint, concrete=True)
  def foo(x, y):
    if y > 0:
      return f(x)
    else:
      return g(x)

with this combination of `static_argnums` and `jax.ensure_compile_time_eval()`:

  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, y):
    with jax.ensure_compile_time_eval():
      y_pos = y > 0
    if y_pos:
      return f(x)
    else:
      return g(x)

See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html

System information

  • OS Platform and Distribution: macOS Catalina 10.15.7
  • Flax, jax, jaxlib versions: flax==0.6.0, jax==0.3.17, jaxlib==0.3.15
  • Python version: 3.10
  • GPU/TPU model and memory: N/A
  • CUDA version (if applicable): N/A

cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Sep 12, 2022

cgarciae (Collaborator) commented

Hey @lucaslingle, thanks for bringing this up! I've opened #2457 with a fix for this.

cgarciae self-assigned this Sep 12, 2022