Non-JITable Initialization is not Allowed #2750
I tried to handle the issue via …
Hi @daskol, I am not able to reproduce your problem, could you please provide a full example (including the implementation of `variance_scaling`)?

```python
from scipy import integrate
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from flax import linen as nn


class Foo(nn.Module):
    power: float = 2

    def setup(self):
        # quad returns a (value, absolute error) tuple.
        self.variance = integrate.quad(
            lambda x, y: x**y, -np.inf, +np.inf, self.power)

        def kernel_init():
            # NB: `variance_scaling` is not defined in this snippet.
            return variance_scaling(self.variance, 'fan_in', 'normal')

        self.layer = nn.Dense(features=10, kernel_init=kernel_init)

    def __call__(self, x):
        a, b = self.variance
        return x + a


rng = random.PRNGKey(0)
model = Foo()
jax.jit(model.init_with_output)(random.PRNGKey(0), jnp.array(2.))
```
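For reference, the undefined `variance_scaling` above is presumably the initializer factory available as `jax.nn.initializers.variance_scaling`. A minimal sketch of how a kernel initializer is built from it, assuming a concrete precomputed scale:

```python
import jax

# variance_scaling(scale, mode, distribution) returns a function with
# the (key, shape, dtype) signature that flax.linen.Dense expects.
scale = 1.0 / 3.0  # e.g. the value of the integral, computed up front
kernel_init = jax.nn.initializers.variance_scaling(scale, 'fan_in', 'normal')

key = jax.random.PRNGKey(0)
kernel = kernel_init(key, (4, 10))  # sample a (4, 10) kernel
```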
Please find the full code in the gist daskol@ef1cfd8fb. Running the tests with the command `pytest cnn_test.py` fails.
Thanks! I think the problem occurs when you pass a JAX array as input into the jitted function:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import quad


def fn(x):
    return quad(lambda a, b: a ** b, -np.inf, +np.inf, x)


def fn2():
    return quad(lambda a, b: a ** b, -np.inf, +np.inf, 0)


jax.jit(fn)(jnp.array(0))  # Fails: a concrete value is expected, but x is abstract while tracing.
jax.jit(fn)(0)             # Fails as well: 0 is converted to a JAX array and is still abstract while tracing.
jax.jit(fn2)()             # Passes: quad now receives a constant.
```

While interesting, at this point the issue is somewhat detached from Flax, so I suggest you continue the discussion on the JAX issue, and I am closing it here. Please reopen if anything changes!
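If the integral genuinely has to depend on a traced value, one possible escape hatch in recent JAX versions is `jax.pure_callback`, which defers the host-side computation to run time so that `quad` only ever sees concrete numbers. A minimal sketch (the helper names are made up here, and finite integration limits are used so the integral converges):

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import quad


def host_quad(power):
    # Runs on the host with a concrete NumPy value, outside of tracing.
    value, _abserr = quad(lambda x, p: x ** p, 0.0, 1.0, args=(float(power),))
    return np.float32(value)


def fn(x):
    # pure_callback re-enters Python at run time, so quad receives a
    # concrete value even though fn itself is traced under jit.
    return jax.pure_callback(host_quad, jax.ShapeDtypeStruct((), jnp.float32), x)


print(jax.jit(fn)(jnp.float32(2.0)))  # integral of x**2 over [0, 1] -> ~0.3333
```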
As far as I understand, one of the recommended ways to initialize a network is to apply `jax.jit` to `nn.Module.init_with_output`. The consequence is that some common pipelines rely on this technique and therefore do not allow any non-jittable transformation during the initialization stage. As an example, I am trying to estimate, in some sense, optimal parameters of a weight initializer, which requires adaptive integration (an invocation of the `scipy.integrate.quad` routine from QUADPACK). Basically, one could expect `scipy.integrate.quad` to be jittable and implemented on every backend, but that is not the case. So, code like the reproducer above fails.
I tried to move the initialization from `setup` to `__post_init__`, but it seems that the parent's `__post_init__` triggers the tracer and raises the same exception. Are there any workarounds? The obvious one is to move the non-jittable computation out of the module scope (sketched below). From my perspective, this issue is a design flaw and should be fixed.