Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jacfwd #734

Merged
merged 2 commits into from
May 23, 2024
Merged

Fix jacfwd #734

merged 2 commits into from
May 23, 2024

Conversation

lockwo
Copy link
Contributor

@lockwo lockwo commented May 21, 2024

I just found that jacfwd breaks if there are no static outputs from the Jacobian. E.g. it fails on the simple neural network on the home page. This fixes that (without the change, the new tests would fail, they would encounter an error because eqx.combine is call on a pytree and a None object)

@patrick-kidger
Copy link
Owner

patrick-kidger commented May 21, 2024

Actually, I think it's more complicated than this. (Possibly even undefined?) The jacobian should be something that is shaped as output x input, but here static_out is something that is shaped according to (the static part of) just output.

I think we need to define these operations in a different way altogether. WDYT?

@lockwo
Copy link
Contributor Author

lockwo commented May 21, 2024

Hmmm, maybe a silly question but what's the actual need for static output? The input is already parsed to be dynamic and the output of the function has to be array types (already if you output a string or something it will fail, since jax.jacfwd doesn't know how to interact with it). Like suppose a simple example where one would probably expect jax and equinox to be the same (i.e. all jax types, no filtering actually needed)

def cond_fun(val):
    x = val
    return jnp.mean(x) <= 5

def body_fun(val):
    x = val
    return x**2

def f(x):
    return jax.lax.while_loop(cond_fun, body_fun, x), 1.0

x = jnp.array([1.0, 2.0])
print("eqx", eqx.filter_jacfwd(f)(x))
print("jax", jax.jacfwd(f)(x))

will yield

eqx (Array([[ 4.,  0.],
       [ 0., 32.]], dtype=float64), 1.0)
jax (Array([[ 4.,  0.],
       [ 0., 32.]], dtype=float64), Array([0., 0.], dtype=float64, weak_type=True))

(this is independent of my PR, since this PR only effects if _static_out is None, but _static_out is not None here).

If we just get ride of static out and return out directly (e.g. return _dynamic_out, (_static_out, _aux) -> return _out, (None, _aux)) then we see

eqx (Array([[ 4.,  0.],
       [ 0., 32.]], dtype=float64), Array([0., 0.], dtype=float64, weak_type=True))
jax (Array([[ 4.,  0.],
       [ 0., 32.]], dtype=float64), Array([0., 0.], dtype=float64, weak_type=True))

@patrick-kidger
Copy link
Owner

Yup. This is what I mean about probably needing to change the definition of filter_jac{fwd,rev}.

The resulting Jacobian is a "nested PyTree", whose structure/array-shapes are formed by composing the PyTrees of the outputs and inputs. But indeed if we have a string as both input and output, then there is no obvious way to compose these.

I suspect we can define this either for the case when the inputs may have static elements, or for the case where the outputs have static elements, but not both.

@lockwo
Copy link
Contributor Author

lockwo commented May 21, 2024

Yes, that makes sense. Personally, I lean towards just not allowing/supporting outputs to have static elements (seems like if you needed static outputs then there is a way for the user to combine them at the end). Static inputs are ubiquitous and seem essential to support.

@patrick-kidger
Copy link
Owner

SGTM!

@lockwo
Copy link
Contributor Author

lockwo commented May 22, 2024

Alright, I removed the static filtering and added a warning about it in the docs

@patrick-kidger patrick-kidger merged commit d4f6a0e into patrick-kidger:main May 23, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

LGTM! Thank you for spotting this, and for the fix! :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants