
Re-enabling scan_f mode for fori_loop and helpfulness of error message #2956

Closed · enhancement (New feature or request)
franciscovargas opened this issue May 4, 2020 · 9 comments

franciscovargas commented May 4, 2020

I am not adding a reproduction example since the issue is quite conceptual and easy to follow from the source (there are even two related TODOs there, so this is more of a reminder):

The most similar example to what I am doing is https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/examples/control.py#L68. I need a fori_loop that can be backpropagated through, i.e. one that uses lax.scan rather than lax.while_loop.
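For concreteness, here is a minimal sketch of such a scan-based fori_loop (my own code, not JAX's implementation; scan supports reverse-mode AD because it records its iterations):

```python
import jax
import jax.numpy as jnp
from jax import lax

def fori_loop_via_scan(lower, upper, body_fun, init_val):
    # Same signature as lax.fori_loop, but lowered to lax.scan so that
    # reverse-mode differentiation works (bounds must be concrete ints).
    def step(carry, i):
        return body_fun(i, carry), None
    carry, _ = lax.scan(step, init_val, jnp.arange(lower, upper))
    return carry

# Example: v starts at 1.0 and is multiplied by x three times, so f(x) = x**3.
f = lambda x: fori_loop_via_scan(0, 3, lambda i, v: v * x, 1.0)
print(jax.grad(f)(2.0))  # 3 * x**2 evaluated at x = 2.0, i.e. 12.0
```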

When trying to use fori_loop while making sure the bounds were constant (static with respect to jit, i.e. convertible to int via int(x)), I get the following error message:

```python
raise ValueError("Reverse-mode differentiation does not work for lax.while_loop. "
                 "Try using lax.scan, or lax.fori_loop with constant bounds.")
```

It's not very clear what constant means (especially since I thought I was already ensuring that). A quick online search using the provided keywords like constant didn't turn up anything useful. Checking the source code, I can see that constant refers to the bounds being castable to int. There is even a related TODO on improving the error messages: https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/jax/lax/lax_control_flow.py#L149
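To illustrate the check (my own example, not JAX source): a bound counts as constant exactly when int(x) succeeds, which fails for values traced under jit:

```python
import jax

int(3)             # fine: a Python int is a "constant" bound

def f(n):
    return int(n)  # under jit, n is a tracer, so int(n) raises a
                   # TypeError and the bound is not "constant"

# jax.jit(f)(3)    # uncommenting this line raises the error above
```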

The type checking is not written in the clearest way, and what is a bit more confusing is that it's redundant at the moment (https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/jax/lax/lax_control_flow.py#L159):

```python
try:
  lower_ = int(lower)
  upper_ = int(upper)
except TypeError:
  use_scan = False
else:
  use_scan = False  # TODO(mattjj): re-enable this
```
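A sketch of what this issue is asking for (the suggestion, not the eventual upstream fix):

```python
try:
  lower_ = int(lower)
  upper_ = int(upper)
except TypeError:
  use_scan = False  # bounds are traced, fall back to while_loop
else:
  use_scan = True   # bounds are constant, so lowering to scan is safe
```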

Is there any reason why this can't be enabled, i.e. setting use_scan to True when the type check for constant bounds passes? I spent quite some time debugging / navigating this.

Sure, a temporary solution is to write my own fori_loop that uses scan, or take it from https://github.com/google/jax/blob/5a0bf46234481887d18f1c8623c8a78d4a2a842e/examples/control.py#L68, but it would be nice if this worked out of the box. I locally enabled the use_scan flag in the scope of the else clause, and that works fine for my current use case.

skye commented May 6, 2020

See #2414 (comment) for why this is disabled. I'll update the error message for now at least.

skye commented May 6, 2020

Also, thanks for this very clear issue!

skye commented May 6, 2020

cc @mattjj

mattjj commented May 6, 2020

Nice find! Since custom_transforms is now deprecated (as of #2026, which we merged four days after I made that comment on #2414), maybe we can now re-enable the lowering to scan.

skye commented May 6, 2020

@mattjj do we have a way to verify if re-enabling it causes problems?

zerodynamics commented Jul 22, 2020

This issue prevents reverse-mode differentiation of the matrix exponential (e.g. grad of f(expm(X))). A lax.fori_loop is called in _squaring (which is called by expm):

https://github.com/google/jax/blob/0a3a5bbb163641c85d599fc7f640ecc172dbbfbf/jax/scipy/linalg.py#L299-L305
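A repro sketch of the failure described above (my own example; on 2020-era JAX this hit the while_loop error quoted earlier, whereas current JAX handles it):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

def f(x):
    # Scalar loss through the matrix exponential; reverse-mode AD over
    # expm failed at the time because _squaring called lax.fori_loop,
    # which lowered to lax.while_loop.
    return jnp.sum(expm(x))

x = jnp.eye(2)
print(jax.grad(f)(x))  # raised the lax.while_loop ValueError back then
```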

@araza6 this may be the source of the error when you backprop through make_unitary.

araza6 commented Jul 23, 2020

@zerodynamics Thanks a lot for your response. Do you think lax.scan would help? Because I suppose JAX's expm differentiation is still WIP, as @mattjj pointed out here.

bionicles commented Aug 19, 2020

expm is used in the M-Layer paper, and it would be cool to try that:

Intelligent Matrix Exponentiation
https://arxiv.org/abs/2008.03936
https://github.com/google-research/google-research/tree/master/m_layer

mattjj commented Jul 23, 2024

I think we fixed this at some point :)

mattjj closed this as completed on Jul 23, 2024