-
-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix jacfwd #734
Fix jacfwd #734
Conversation
Actually, I think it's more complicated than this. (Possibly even undefined?) The jacobian should be something that is shaped as I think we need to define these operations in a different way altogether. WDYT? |
Hmmm, maybe a silly question but what's the actual need for static output? The input is already parsed to be dynamic and the output of the function has to be array types (already if you output a string or something it will fail, since jax.jacfwd doesn't know how to interact with it). Like suppose a simple example where one would probably expect jax and equinox to be the same (i.e. all jax types, no filtering actually needed) def cond_fun(val):
x = val
return jnp.mean(x) <= 5
def body_fun(val):
x = val
return x**2
def f(x):
return jax.lax.while_loop(cond_fun, body_fun, x), 1.0
x = jnp.array([1.0, 2.0])
print("eqx", eqx.filter_jacfwd(f)(x))
print("jax", jax.jacfwd(f)(x)) will yield
(this is independent of my PR, since this PR only effects if _static_out is None, but _static_out is not None here). If we just get ride of static out and return out directly (e.g.
|
Yup. This is what I mean about probably needing to change the definition of The resulting Jacobian is a "nested PyTree", whose structure/array-shapes are formed by composing the PyTrees of the outputs and inputs. But indeed if we have a string as both input and output, then there is no obvious way to compose these. I suspect we can define this either for the case when the inputs may have static elements, or for the case where the outputs have static elements, but not both. |
Yes, that makes sense. Personally, I lean towards just not allowing/supporting outputs to have static elements (seems like if you needed static outputs then there is a way for the user to combine them at the end). Static inputs are ubiquitous and seem essential to support. |
SGTM! |
Alright, I removed the static filtering and added a warning about it in the docs |
LGTM! Thank you for spotting this, and for the fix! :) |
I just found that jacfwd breaks if there are no static outputs from the Jacobian. E.g. it fails on the simple neural network on the home page. This fixes that (without the change, the new tests would fail, they would encounter an error because eqx.combine is call on a pytree and a None object)