Problem you have encountered:
This may already be on the Flax team's radar, but I noticed that when using flax.linen.remat, setting concrete=True raises a NotImplementedError with JAX 0.3.17, for the reasons discussed here.
As of version 0.6.0, flax.linen.remat:
(1) passes concrete=True through to jax.remat, which raises a NotImplementedError (see the traceback below);
(2) does not accept a static_argnums argument, which the latest jax.remat uses in place of concrete.
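For point (2), this is roughly what the static_argnums replacement for concrete=True looks like when calling jax.checkpoint (jax.remat is an alias for it) directly. It is a minimal sketch mirroring the is_training example in JAX's error message; f, g, foo, and loss are hypothetical placeholders, and it assumes a JAX version whose jax.checkpoint accepts static_argnums (0.3.17, per the traceback below, or later).

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical branch used while training.
    return 2.0 * jnp.sin(x)


def g(x):
    # Hypothetical branch used at evaluation time.
    return jnp.sin(x)


# Argument 1 (is_training) is static, so the Python `if` below sees a plain
# bool instead of a tracer; this replaces the old concrete=True behaviour.
@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
    if is_training:
        return f(x)
    else:
        return g(x)


def loss(x, is_training):
    return jnp.sum(foo(x, is_training))


x = jnp.arange(3.0)
# jax.grad differentiates only w.r.t. x; the flag stays a Python bool.
grad_train = jax.grad(loss)(x, True)  # equals 2 * cos(x)
grad_eval = jax.grad(loss)(x, False)  # equals cos(x)
```

Each distinct value of the static flag leads to a separate trace of foo, which is the usual cost of making an argument static.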
Interestingly, pip's dependency resolver did not flag this incompatibility: installing jax and flax with pip gave me flax==0.6.0 alongside jax==0.3.17, which is the combination that produces the problem.
As a workaround, I've downgraded to jax==0.3.16, and am running jax.config.update("jax_new_checkpoint", False) at the top of my scripts, as suggested by the link above.
What you expected to happen:
To stay compatible with JAX's remat functionality, future versions of flax.linen.remat would ideally accept a static_argnums argument and forward it to jax.remat.
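One wrinkle for that design: Flax runs the user's function inside a lifted function with extra leading parameters (the traceback below shows rematted(variable_groups, rng_groups, *args, **kwargs) in flax/core/lift.py), so user-facing static_argnums indices would need to be offset by the number of prepended arguments before reaching jax.remat. The sketch below is a hypothetical, simplified wrapper that prepends a single scale argument; remat_with_scale and branchy are made-up names, and this is not Flax's actual implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp


def remat_with_scale(fn, static_argnums=()):
    """Hypothetical remat wrapper that prepends one extra argument.

    The lifted function takes `scale` before the user's arguments, so each
    user-facing static index is shifted by one before reaching
    jax.checkpoint. Passing static_argnums through unshifted would mark
    the wrong argument as static.
    """
    shifted = tuple(i + 1 for i in static_argnums)

    @partial(jax.checkpoint, static_argnums=shifted)
    def lifted(scale, *args):
        return scale * fn(*args)

    return lifted


def branchy(x, is_training):
    # Python control flow on a flag that must remain static.
    return jnp.tanh(x) if is_training else x


# The user refers to is_training as index 1 of branchy's own signature.
remat_branchy = remat_with_scale(branchy, static_argnums=(1,))

x = jnp.array([0.5, -1.0])
y_train = remat_branchy(2.0, x, True)  # 2 * tanh(x)
y_eval = remat_branchy(2.0, x, False)  # 2 * x
g_train = jax.grad(lambda v: jnp.sum(remat_branchy(2.0, v, True)))(x)
```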
In the traceback triggered by Flax passing concrete=True, the JAX developers also remark that "If jax.numpy operations need to be performed on static arguments, we can use the jax.ensure_compile_time_eval() context manager," which may also be relevant to the future design of flax.linen.remat.
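As a sketch of that remark: jax.numpy work on a static argument can be forced to run eagerly inside jax.ensure_compile_time_eval(), so the result is a concrete value that Python control flow can branch on. Without the context manager, those operations would typically be staged out while jax.checkpoint traces the function, and the `if` would see a tracer. The function scaled and its tuple-valued weights argument are hypothetical; a tuple is used so the static value is hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `weights` (argument 1) is static, but we still want jax.numpy to decide
# which branch to take.
@partial(jax.checkpoint, static_argnums=(1,))
def scaled(x, weights):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly, so `positive` is a concrete value rather than a
        # tracer, even though jax.checkpoint is tracing `scaled`.
        positive = jnp.sum(jnp.asarray(weights)) > 0
    if positive:
        return 2.0 * x
    return -x


x = jnp.arange(3.0)
out_pos = scaled(x, (1.0, 0.5))   # 2 * x
out_neg = scaled(x, (-1.0, 0.5))  # -x
grad_pos = jax.grad(lambda v: jnp.sum(scaled(v, (1.0, 0.5))))(x)  # all 2.0
```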
Steps to reproduce:
The problem can be reproduced by running the script
Logs, error messages, etc:
When I run the above script, I obtain the following traceback:
Traceback (most recent call last):
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 17, in <module>
    params = foo.init({"params": sk2}, input)["params"]
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1273, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1229, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/scope.py", line 897, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/scope.py", line 865, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 1647, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 361, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/module.py", line 657, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 9, in __call__
    return self.linear(inputs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/linen/transforms.py", line 316, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/lift.py", line 213, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/flax/core/lift.py", line 1177, in inner
    def rematted(variable_groups, rng_groups, *args, **kwargs):
  File "/Users/lucaslingle/opt/miniconda3/envs/project123/lib/python3.10/site-packages/jax/_src/api.py", line 3084, in checkpoint
    raise NotImplementedError(msg)
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: The 'concrete' option to jax.checkpoint / jax.remat is deprecated; in its place, you can use its `static_argnums` option, and if necessary the `jax.ensure_compile_time_eval()` context manager.
For example, if using `concrete=True` for an `is_training` flag:
  from functools import partial
  @partial(jax.checkpoint, concrete=True)
  def foo(x, is_training):
    if is_training:
      return f(x)
    else:
      return g(x)
replace it with a use of `static_argnums`:
  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, is_training):
    ...
If jax.numpy operations need to be performed on static arguments, we can use the `jax.ensure_compile_time_eval()` context manager. For example, we can replace this use of `concrete=True`
:
  @partial(jax.checkpoint, concrete=True)
  def foo(x, y):
    if y > 0:
      return f(x)
    else:
      return g(x)
with this combination of `static_argnums` and `jax.ensure_compile_time_eval()`:
  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, y):
    with jax.ensure_compile_time_eval():
      y_pos = y > 0
    if y_pos:
      return f(x)
    else:
      return g(x)
See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 17, in <module>
    params = foo.init({"params": sk2}, input)["params"]
  File "/Users/lucaslingle/PycharmProjects/project123/src/project123/nn/generic_module.py", line 9, in __call__
    return self.linear(inputs)
NotImplementedError: The 'concrete' option to jax.checkpoint / jax.remat is deprecated; in its place, you can use its `static_argnums` option, and if necessary the `jax.ensure_compile_time_eval()` context manager.
For example, if using `concrete=True` for an `is_training` flag:
  from functools import partial
  @partial(jax.checkpoint, concrete=True)
  def foo(x, is_training):
    if is_training:
      return f(x)
    else:
      return g(x)
replace it with a use of `static_argnums`:
  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, is_training):
    ...
If jax.numpy operations need to be performed on static arguments, we can use the `jax.ensure_compile_time_eval()` context manager. For example, we can replace this use of `concrete=True`
:
  @partial(jax.checkpoint, concrete=True)
  def foo(x, y):
    if y > 0:
      return f(x)
    else:
      return g(x)
with this combination of `static_argnums` and `jax.ensure_compile_time_eval()`:
  @partial(jax.checkpoint, static_argnums=(1,))
  def foo(x, y):
    with jax.ensure_compile_time_eval():
      y_pos = y > 0
    if y_pos:
      return f(x)
    else:
      return g(x)
See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html
System information
OS Platform and Distribution: MacOS Catalina 10.15.7
Flax, jax, jaxlib versions: flax==0.6.0, jax==0.3.17, jaxlib==0.3.15
Python version: 3.10