Re-enabling scan_f mode for fori_loop and helpfulness of error message #2956
Comments
See #2414 (comment) for why this is disabled. I'll update the error message for now at least.
Also, thanks for this very clear issue!
cc @mattjj
@mattjj do we have a way to verify whether re-enabling it causes problems?
This issue prevents reverse-mode differentiation of the matrix exponential (e.g. @araza6, this may be the source of the error when you backprop through
@zerodynamics Thanks a lot for your response. Do you think
`expm` is used in the M-layer paper, Intelligent Matrix Exponentiation, and it would be cool to try that.
I think we fixed this at some point :)
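Whether reverse-mode differentiation through the matrix exponential works on the installed JAX version can be checked directly against `jax.scipy.linalg.expm`. The sketch below is illustrative and assumes a JAX release that includes the fix mentioned above. It uses a hypothetical test function `trace_expm`: for the diagonal matrix A = diag(1, 2), tr(exp(tA)) = e^t + e^{2t}, so the exact derivative e^t + 2e^{2t} is known in closed form.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Hypothetical test function: tr(expm(t * A)) for a fixed diagonal A.
A = jnp.diag(jnp.array([1.0, 2.0]))

def trace_expm(t):
    return jnp.trace(expm(t * A))

# Reverse-mode derivative; for this diagonal A the exact value is e^t + 2 e^{2t}.
t = 1.0
grad_val = jax.grad(trace_expm)(t)
exact = jnp.exp(t) + 2.0 * jnp.exp(2.0 * t)
print(grad_val, exact)
```

If `jax.grad` raises an error about reverse-mode differentiation of a while loop here, the installed version predates the fix.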
I am not adding an example to reproduce this, since it's quite conceptual and it's easy to follow what is going on (there are even two related TODOs in the source, so this is more of a reminder):
The most similar example to what I am doing is https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/examples/control.py#L68. I need a `fori_loop` that can be backpropagated through, i.e. one that uses `scan` rather than `lax.while_loop`.

When I try to use `fori_loop` while making sure the bounds are constant (static with respect to `jit`, i.e. convertible to `int` via `int(x)`), I get the following error message:

It's not very clear what "constant" means, especially since I thought I was already ensuring that. A quick online search using the keywords from the message, such as "constant", turned up nothing useful. Checking the source code, I can see that "constant" refers to the bounds being castable to `int`. There's even a related TODO on improving the error messages: https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/jax/lax/lax_control_flow.py#L149
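In the linked source, "constant" amounts to the bound surviving a Python `int(...)` conversion. Below is a minimal sketch of that check, using a hypothetical helper named `bounds_are_static`. Concrete Python or JAX integers convert fine. A value passed as an argument to a `jax.jit`-compiled function is an abstract tracer, however, and `int()` on it raises a `TypeError` subclass. That is why bounds derived from a jitted function's arguments do not count as "constant", even when the caller passes an ordinary integer.

```python
import jax
import jax.numpy as jnp

def bounds_are_static(lower, upper):
    """Hypothetical helper: True if both loop bounds can be cast with int()."""
    try:
        int(lower)
        int(upper)
        return True
    except TypeError:
        # JAX tracers raise a TypeError subclass (a concretization error).
        return False

print(bounds_are_static(0, 10))             # plain Python ints -> True
print(bounds_are_static(0, jnp.int32(10)))  # concrete JAX array -> True

seen = []

@jax.jit
def f(n):
    # Under jit, `n` is a tracer, so int(n) fails while tracing.
    seen.append(bounds_are_static(0, n))
    return n + 1

f(10)
print(seen)  # -> [False]
```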
The type checking is not written in the clearest way, and what is more confusing is that it is currently redundant: https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/jax/lax/lax_control_flow.py#L159
Is there any reason why this can't be enabled, i.e. set `use_scan` to true if the type check for constant bounds passes? I spent quite some time debugging and navigating this.
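Below is a minimal sketch of the requested behaviour, written outside the library with a hypothetical wrapper `scan_fori_loop` that uses only public `jax.lax` APIs. This is not JAX's internal implementation. When both bounds pass the `int()` check, the wrapper runs the loop with `lax.scan` over `jnp.arange(lower, upper)`, which supports reverse-mode differentiation. Otherwise it falls back to `lax.fori_loop` with traced bounds.

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_fori_loop(lower, upper, body_fun, init_val):
    """Hypothetical fori_loop variant: use scan when the bounds are static."""
    try:
        lo, hi = int(lower), int(upper)
    except TypeError:
        # Bounds are tracers: defer to lax.fori_loop, which then lowers to a
        # while_loop and is not reverse-mode differentiable.
        return lax.fori_loop(lower, upper, body_fun, init_val)

    def step(carry, i):
        # scan's body returns (new_carry, per-step output); only the carry is needed.
        return body_fun(i, carry), None

    final, _ = lax.scan(step, init_val, jnp.arange(lo, hi))
    return final

def power(x, n=5):
    # x ** n by repeated multiplication; with the default n the bounds are static.
    return scan_fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))

print(power(2.0))                     # 2**5 = 32
print(jax.grad(power)(2.0))           # 5 * 2**4 = 80
print(jax.jit(jax.grad(power))(2.0))  # n is closed over, so still the scan path
print(jax.jit(power)(2.0, 5))         # n is traced: forward-only fallback
```

Passing the loop index as scan's `xs` keeps the `(i, val)` body signature of `lax.fori_loop`. The current `lax.fori_loop` documentation states that a static trip count makes it lower to `scan`. On such versions, this wrapper mostly illustrates the idea.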
Sure, a temporary solution is to write my own `fori_loop` that uses `scan`, or to take it from https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/examples/control.py#L68, but it would be nice if this worked. I locally enabled the `use_scan` flag in the scope of the `else` clause, and that works fine for my current use case.