leaked trace error when using custom_jvp, scan, eqx.Module, and jit #745

Closed
ToshiyukiBandai opened this issue Jun 4, 2024 · 13 comments · Fixed by #754

Labels
question User queries

Comments

@ToshiyukiBandai commented Jun 4, 2024

Hi all,

I encountered a leaked trace error when using custom_jvp, scan, eqx.Module, and jit at the same time with the latest versions of JAX and Equinox (JAX 0.4.28, Equinox 0.11.4). I did not see this error with older versions, but I could not reproduce that setup. The example below raises a leaked trace error when evaluating the gradient of the function loss, which takes an eqx.Module instance as an argument. It seems that param in the class Fun is leaked. There is no error if I pass a plain function as the argument instead. I would appreciate any help!

Updated:
param in the class Fun was originally a Python float, but is now a jax.Array.

import functools as ft
import jax
from jax import jit, grad, lax, custom_jvp
import jax.numpy as jnp
import equinox as eqx

print(eqx.__version__) # 0.11.4

print(jax.__version__) # 0.4.28

platform = 'cpu'
# platform = 'gpu'
jax.config.update('jax_platform_name', platform)
print(jax.devices())

class Fun(eqx.Module):
    param: jnp.array
    def __call__(self, x):
        return self.param * jnp.sin(x) 

# fun = jnp.sin # this is okay

fun = Fun(jnp.array(1.0))

fun(1.0)

@ft.partial(custom_jvp, nondiff_argnums=(0,))
def f(fun, x, y):
    return fun(x)*y

@f.defjvp
def f_jvp(fun, primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(fun, x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

f(fun, 2.0, 3.0)

grad_f = eqx.filter_jit(grad(f, argnums = (1,)))
# grad_f = jit(grad(f, argnums = (1,))) # this works too

grad(f, argnums = (1))(fun, 2., 3.)

def fun2(fun, x, y, steps):
    def scan_fun(carry, step):
        x, y = carry
        z = f(fun, x, y)
        carry = x * step, y * step
        return carry, z
    init_carry = (x, y)
    
    carry, output = jax.lax.scan(scan_fun, init_carry, steps)
    return carry, output

carry, output = fun2(fun, 2., 3., jnp.array([2.0, 2.0]))

def loss(fun, x, y, steps):
    carry, output = fun2(fun, x, y, steps)
    return output.sum()

with jax.checking_leaks():
    jit(grad(loss, argnums = (1)))(fun, 2., 3., jnp.array([2.0, 2.0])) # Leaked trace error!
    # eqx.filter_jit(grad(loss, argnums = (1)))(fun, 2., 3., jnp.array([2.0, 2.0])) # works when fun = jnp.sin


@ToshiyukiBandai (Author)

Okay, one way to get around this is to use eqx.field(static=True) for the param attribute in the class. However, I would prefer not to do this, because it would take a lot of work to change my code...

class Fun(eqx.Module):
    param: jnp.array = eqx.field(static=True)
    def __call__(self, x):
        return self.param * jnp.sin(x) 

@ToshiyukiBandai (Author)

It turns out that using ft.partial is a good workaround:

grad_loss = grad(ft.partial(loss, fun))
jit(grad_loss)(2., 3., jnp.array([2.0, 2.0]))

@patrick-kidger (Owner)

The issue is this bit: @ft.partial(custom_jvp, nondiff_argnums=(0,)).

You've marked an array (inside Fun) as not-an-array, which is the leak.

Try using eqx.filter_custom_jvp instead.

@patrick-kidger added the question User queries label Jun 4, 2024

@ToshiyukiBandai (Author)

Thank you for the suggestion! I modified the custom JVP rule as below, but I got an error when evaluating the gradient: ValueError: Received keyword tangent. How can I pass the function fun without using a keyword argument?

import functools as ft
import jax
from jax import jit, grad, lax, custom_jvp
import jax.numpy as jnp
import equinox as eqx

print(eqx.__version__) # 0.11.4
print(jax.__version__) # 0.4.28

platform = 'cpu'
# platform = 'gpu'
jax.config.update('jax_platform_name', platform)
print(jax.devices())

class Fun(eqx.Module):
    param: jnp.array
    def __call__(self, x):
        return self.param * jnp.sin(x) 

# fun = jnp.sin # this is okay

fun = Fun(jnp.array(1.0))

fun(jnp.array(1.0))

@eqx.filter_custom_jvp
def f(x, y, *, fun):
    return fun(x)*y

@f.def_jvp
def f_jvp(primals, tangents, *, fun):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y, fun = fun)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

f(jnp.array(2.0), jnp.array(3.0), fun = fun)

grad_f = eqx.filter_jit(eqx.filter_grad(f))

grad_f(jnp.array(2.0), jnp.array(3.0), fun = fun) # ValueError: Received keyword tangent

eqx.filter_grad(f)(jnp.array(2.0), jnp.array(3.0), fun = fun) # ValueError: Received keyword tangent

@lockwo (Contributor) commented Jun 5, 2024

Also, the example in the docs doesn't seem to be very robust:

@eqx.filter_custom_jvp
def call(x, y, *, fn):
    return fn(x, y)

@call.def_jvp
def call_jvp(primals, tangents, *, fn):
    x, y = primals
    tx, ty = tangents
    primal_out = call(x, y, fn=fn)
    tangent_out = tx**2 + ty
    return primal_out, tangent_out

eqx.filter_grad(call)(jnp.array(2.0), jnp.array(3.0), fn = lambda a, b: a + b)

fails. (The fn here is arbitrary; none of the functions I tried worked.)

<ipython-input-15-40164c8d4502> in call_jvp(primals, tangents, fn)
      8     tx, ty = tangents
      9     primal_out = call(x, y, fn=fn)
---> 10     tangent_out = tx**2 + ty
     11     return primal_out, tangent_out
     12 

TypeError: unsupported operand type(s) for +: 'JaxprTracer' and 'NoneType'

patrick-kidger added a commit that referenced this issue Jun 7, 2024
See #745 (comment).

Also improves the documentation into a larger example, to help make clear why some tangent may be `None`.

@patrick-kidger (Owner)

Thanks both!
@ToshiyukiBandai -- you're right, this is a spurious crash. This should be fixed in #749.
@lockwo -- the behaviour you're seeing is expected, and is due to ty having a symbolic zero tangent. (It's not being differentiated by equinox.filter_grad, as that only differentiates its first argument.) I've updated the docs in #749 to help make clear what's going on!
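
For instance, here is a minimal sketch of a None-safe version of the rule from the snippet above (same call and fn; the tangent is also kept linear in tx, since tx**2 has no transpose rule for reverse mode, as comes up below):

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_custom_jvp
def call(x, y, *, fn):
    return fn(x, y)

@call.def_jvp
def call_jvp(primals, tangents, *, fn):
    x, y = primals
    tx, ty = tangents
    primal_out = call(x, y, fn=fn)
    # `ty` is None (a symbolic zero) here: `y` is not differentiated by
    # eqx.filter_grad, which only differentiates its first argument.
    tangent_out = tx if ty is None else tx + ty
    return primal_out, tangent_out

eqx.filter_grad(call)(jnp.array(2.0), jnp.array(3.0), fn=lambda a, b: a + b)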

patrick-kidger added a commit that referenced this issue Jun 9, 2024
See #745 (comment).

Also improves the documentation into a larger example, to help make clear why some tangent may be `None`.

@ToshiyukiBandai (Author)

Hi @patrick-kidger, thank you for checking and fixing the docs example. I ran it but got the error below. The error goes away if I change tx**2 to tx.

NotImplementedError                       Traceback (most recent call last)
Cell In[6], line 16
     14 fn = lambda a, b: a + b
     15 # This only computes gradients for the first argument `x`.
---> 16 eqx.filter_grad(call)(x, y, fn=fn)

File ~/miniforge3/envs/jax/lib/python3.9/site-packages/equinox/_ad.py:96, in _GradWrapper.__call__(self, *args, **kwargs)
     95 def __call__(self, /, *args, **kwargs):
---> 96     value, grad = self._fun_value_and_grad(*args, **kwargs)
     97     if self._has_aux:
     98         _, aux = value

File ~/miniforge3/envs/jax/lib/python3.9/site-packages/equinox/_ad.py:79, in _ValueAndGradWrapper.__call__(self, *args, **kwargs)
     77 x, *args = args
     78 diff_x, nondiff_x = partition(x, is_inexact_array)
---> 79 return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)

    [... skipping hidden 7 frame]

File ~/miniforge3/envs/jax/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:285, in get_primitive_transpose(p)
    283   return primitive_transposes[p]
    284 except KeyError as err:
--> 285   raise NotImplementedError(
    286       "Transpose rule (for reverse-mode differentiation) for '{}' "
    287       "not implemented".format(p)) from err

NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'integer_pow' not implemented

@patrick-kidger (Owner) commented Jun 10, 2024

Whoops! You're completely correct. I must have tweaked that after I checked it ran. Thank you for the heads up -- now fixed in #754!

@ToshiyukiBandai (Author)

@patrick-kidger
Thank you for fixing the doc. I was able to run the example with the new Equinox. But I have a question: does the current @eqx.filter_custom_jvp assume the second positional argument (y in call in the docs example) to be non-differentiable? In the example below (almost the same as the one I provided above), the second positional argument seems to be treated as non-differentiable, which I did not intend, and this results in a TypeError. Did I miss anything?

import functools as ft
import jax
from jax import jit, grad, lax, custom_jvp
import jax.numpy as jnp
import equinox as eqx

print(eqx.__version__) # 0.11.4

print(jax.__version__) # 0.4.29

class Fun(eqx.Module):
    param: jnp.array
    def __call__(self, x):
        return self.param * jnp.sin(x) 

# fun = jnp.sin # this is okay

fun = Fun(jnp.array(1.0))

fun(jnp.array(1.0))

@eqx.filter_custom_jvp
def f(x, y, *, fun):
    return fun(x)*y

@f.def_jvp
def f_jvp(primals, tangents, *, fun):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y, fun = fun)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

eqx.filter_jit(eqx.filter_grad(f))(jnp.array(2.0), jnp.array(3.0), fun = fun) # TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'NoneType'

@patrick-kidger (Owner)

In this case it is filter_grad that is doing this: it differentiates only the first argument of the function it wraps.

(If you need to differentiate multiple things then wrap them together into a tuple; more generally into any PyTree.)
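
For example, here is a minimal sketch reusing the f and fun from the previous comment (f_packed is a hypothetical helper):

# Pack (x, y) into a single PyTree so that filter_grad, which only
# differentiates its first argument, differentiates both at once.
def f_packed(xy, *, fun):
    x, y = xy
    return f(x, y, fun=fun)

# Returns a tuple of gradients matching the structure of the input tuple.
grads = eqx.filter_grad(f_packed)((jnp.array(2.0), jnp.array(3.0)), fun=fun)

With both x and y differentiated, both tangents in the JVP rule are arrays, so the rule runs without hitting a None.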

@ToshiyukiBandai (Author)

@patrick-kidger
Thank you for your quick response. Actually, I get the same error even when using plain grad, as below.

jit(grad(f))(jnp.array(2.0), jnp.array(3.0), fun = fun) # TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'NoneType'

Interestingly, if I evaluate the derivative with respect to both arguments, it works:

grad(f, argnums=(0, 1))(jnp.array(2.0), jnp.array(3.0), fun = fun)

I guess something is going on, but I have no clue what...

@patrick-kidger (Owner)

Your JVP has the line jnp.sin(x) * y_dot. Each tangent passed to an equinox.filter_custom_jvp-wrapped JVP rule can be either an array or None. The latter corresponds to a symbolic zero, i.e. when a value depends only on undifferentiated inputs.

In this case y is not being differentiated, so its tangent is a symbolic zero, which is a None, and so you see the error.

Note that all of the above is true regardless of whether you use jax.grad or equinox.filter_grad -- everything you're seeing here is a property of equinox.filter_custom_jvp.
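
Concretely, a minimal sketch of a None-safe way to write the f_jvp above, treating a None tangent as a symbolic zero and skipping its contribution:

@f.def_jvp
def f_jvp(primals, tangents, *, fun):
    x, y = primals
    x_dot, y_dot = tangents  # each is an array or None (a symbolic zero)
    primal_out = f(x, y, fun=fun)
    tangent_out = jnp.zeros_like(primal_out)
    if x_dot is not None:
        tangent_out = tangent_out + jnp.cos(x) * x_dot * y
    if y_dot is not None:
        tangent_out = tangent_out + jnp.sin(x) * y_dot
    return primal_out, tangent_out

This also explains why grad(f, argnums=(0, 1)) above works: when both arguments are differentiated, neither tangent is None.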

@ToshiyukiBandai (Author)

Understood, thank you for the explanation!
