
Non-JITable Initialization is not Allowed #2750

Closed
daskol opened this issue Dec 21, 2022 · 4 comments
daskol commented Dec 21, 2022

As far as I understand, one of the recommended ways to initialize a network is to apply jax.jit to nn.Module.init_with_output. The issue is that common pipelines rely on this technique and therefore do not allow any non-jittable computation during initialization. As an example, I am trying to estimate, in some sense, optimal parameters for a weight initializer, which requires adaptive integration (an invocation of the scipy.integrate.quad routine from QUADPACK). One might expect scipy.integrate.quad to be jittable and implemented on every backend, but that is not the case.

import flax.linen as nn
import numpy as np
import scipy as sp
import scipy.integrate  # ensure the subpackage is loaded
from flax.linen.initializers import variance_scaling

# fn is the integrand; its implementation is in the gist linked below.


class Model(nn.Module):

    depth: int = 1000

    def setup(self):
        # quad returns an (estimate, abserr) pair; keep only the estimate.
        self.variance, _ = sp.integrate.quad(fn, -np.inf, +np.inf, 1 / self.depth)
        # variance_scaling already returns an initializer with the expected
        # (key, shape, dtype) signature, so it can be passed to Dense directly.
        kernel_init = variance_scaling(self.variance, 'fan_in', 'normal')
        self.layer = nn.Dense(features=10, kernel_init=kernel_init)

So, the following code fails.

model = Model()
jax.jit(model.init_with_output)(key, batch)
# jax._src.errors.ConcretizationTypeError:
#     Abstract tracer value encountered where concrete value is expected:
#         Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

I tried to move the initialization from setup to __post_init__, but it seems that the parent's __post_init__ triggers the tracer and raises the same exception.

class Model(nn.Module):

    depth: int = 1000

    def __post_init__(self):
        # Attributes must be assigned before the parent's __post_init__,
        # which finalizes the module.
        self.variance, _ = sp.integrate.quad(fn, -np.inf, +np.inf, 1 / self.depth)
        super().__post_init__()

Are there any workarounds? The obvious one is to move the non-jittable computation out of the module scope. From my perspective, this issue is a design flaw and should be fixed.


daskol commented Dec 22, 2022

I tried to handle the issue via jax.experimental.host_callback.call, but I didn't manage to make it work. Please see the details in jax-ml/jax#13762.
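For reference, a related escape hatch in newer JAX versions is jax.pure_callback, which stages a host-side Python call into jitted code by declaring the result's shape and dtype up front. A minimal sketch with an illustrative Gaussian-moment integrand (not the fn from this report):

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.integrate


def quad_host(power):
    # Runs in ordinary Python on the host, so quad receives a concrete value.
    value, _ = scipy.integrate.quad(
        lambda x: np.exp(-x * x) * x ** power, 0.0, np.inf)
    return np.float32(value)


@jax.jit
def fn(power):
    # The result's shape and dtype must be declared, since the callback's
    # body is opaque to the tracer.
    return jax.pure_callback(
        quad_host, jax.ShapeDtypeStruct((), jnp.float32), power)


result = fn(2.0)  # integral of x^2 * exp(-x^2) over [0, inf) = sqrt(pi) / 4
```

Unlike host_callback.call, pure_callback is a supported public API, though the callback must be pure and its cost is a host round-trip on every invocation.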

@marcvanzee (Collaborator) commented

Hi @daskol, I am not able to reproduce your problem. Could you please provide a full example (including the implementations of __call__ and fn)? This works fine:

from scipy import integrate
import jax.numpy as jnp
import jax
from jax import random
import numpy as np

from flax import linen as nn
from flax.linen.initializers import variance_scaling

class Foo(nn.Module):
  power: float = 2

  def setup(self):
    self.variance = integrate.quad(lambda x, y: x**y, -np.inf, +np.inf, self.power)

    def kernel_init():
      return variance_scaling(self.variance, 'fan_in', 'normal')

    self.layer = nn.Dense(features=10, kernel_init=kernel_init)

  def __call__(self, x):
    a, b = self.variance
    return x + a

rng = random.PRNGKey(0)
model = Foo()
jax.jit(model.init_with_output)(random.PRNGKey(0), jnp.array(2.))


daskol commented Dec 22, 2022

Please find the full code in the gist daskol@ef1cfd8fb. Running the tests with the following command fails:

pytest cnn_test.py

@marcvanzee (Collaborator) commented

Thanks! I think the problem occurs when you pass a JAX array as input to the quad function and then jit-compile it. JAX traces through the function to compile it, using abstract values for the JAX arrays, but quad needs concrete values:

from scipy.integrate import quad
import jax
import jax.numpy as jnp
import numpy as np

def fn(x):
  return quad(lambda a, b: a ** b, -np.inf, +np.inf, x)

def fn2():
  return quad(lambda a, b: a ** b, -np.inf, +np.inf, 0)

jax.jit(fn)(jnp.array(0))  # This fails: a concrete value is expected but x is abstract when tracing
jax.jit(fn)(0)  # This fails as well: 0 is converted to a JAX array and is still abstract when tracing
jax.jit(fn2)()  # This passes: quad now receives a constant
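Another way to keep the argument concrete, not mentioned above, is to mark it static: jit then treats it as a compile-time constant (at the cost of recompiling for each distinct value). A hedged sketch over finite bounds, since the original improper integral of a ** b diverges:

```python
from scipy.integrate import quad
import jax


def fn(x):
    # quad runs at trace time here, because x arrives as a plain Python value.
    return quad(lambda a, b: a ** b, 0.0, 1.0, x)


# With static_argnums, each distinct x triggers a recompilation, but quad
# always receives a concrete number; the integral of a**2 over [0, 1] is 1/3.
value, abserr = jax.jit(fn, static_argnums=0)(2)
```
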

While interesting, at this point the issue is somewhat detached from Flax, so I suggest you continue the discussion on the JAX issue; I am closing it here. Please reopen if anything changes!
