jax.grad computes incorrect derivative for polynomials #14397
Comments
Thanks for raising this. It looks like we can also repro with just
Ah, now I remember... see #12033. We used to have I'm not sure how to fix this yet!
Okay, thanks to @dougalm for figuring much of this out and @hawkinsp for explaining some of the current code. My current understanding is that inside
This bug is arising because we are accidentally conflating these functions. In particular, we should differentiate function (1) differently from function (2) or (3), since the former is differentiable at (In addition to autodiff, it may make sense to disentangle these for performance reasons.) To fix this bug, we ultimately need to have different autodiff behavior depending on which function we're working with, which we can infer from the types of the arguments. That's my plan, though I'm not sure yet how to organize the code (i.e. whether to make one primitive which handles differently-typed inputs, or instead just to have the first function be non-primitive, since there's no XLA HLO op for it anyway).
fixes jax-ml#14397 For autodiff purposes (and eventually for evaluation implementation purposes) we need to distinguish between pow :: inexact -> int -> inexact (which is differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which isn't); see jax-ml#14397 (comment). Instead of making a new primitive, we made the old one polymorphic and switch its behavior on the element type of its second argument. Co-authored-by: Roy Frostig <frostig@google.com>
fixes jax-ml#14397 For autodiff purposes (and eventually for evaluation implementation purposes) we need to distinguish between pow :: inexact -> int -> inexact (which is differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which isn't); see jax-ml#14397 (comment). Instead of making a new primitive, we made the old one polymorphic and switch its behavior on the element type of its second argument. There were also some other cases with special handling for algorithmic reasons (e.g. doing binary exponentiation), so these autodiff cases had to be merged with those algorithmic cases. Co-authored-by: Roy Frostig <frostig@google.com>
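The dtype switch described in the commit message can be sketched outside of JAX internals. The helper below, `pow_lhs_jacobian`, is a hypothetical name and not the rule JAX ships; it only illustrates, under that assumption, how picking the derivative rule from the exponent's element type avoids the `nan` at `(0., 0)`:

```python
import jax.numpy as jnp

def pow_lhs_jacobian(x, y):
    """Illustrative d/dx of x**y that switches on the dtype of y.

    Hypothetical helper for this thread; not JAX's actual implementation.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y)
    if jnp.issubdtype(y.dtype, jnp.integer):
        # Integer exponent: x**0 is the constant 1, so its derivative is
        # exactly 0 everywhere, including at x == 0.
        safe_y = jnp.where(y == 0, 1, y).astype(x.dtype)
        return jnp.where(y == 0, 0.0, safe_y * x ** (safe_y - 1))
    # Floating exponent: general rule y * x**(y - 1), which evaluates
    # 0 * 0**(-1) = 0 * inf at (0., 0.).
    return y * x ** (y - 1)

int_jac = pow_lhs_jacobian(0.0, jnp.arange(3))      # exponents 0, 1, 2 as ints
float_jac = pow_lhs_jacobian(0.0, jnp.arange(3.0))  # same exponents as floats
print(int_jac, float_jac)
```

The integer branch masks the `y == 0` entry before the general formula can produce `0 * inf`, while the floating branch leaves that entry as `nan`.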
Does anyone have a temporary workaround for this?
There's a workaround mentioned in #14745 (comment). Though I'm kind of embarrassed this issue is still open... thanks for the ping.
This issue should be fixed at github HEAD! We haven't updated the pypi version yet, but if you install from github the issue should be gone.
@mattjj Finally got around to checking this, and I'm finding that this bug is not fixed in v0.4.16.dev20230901. To reproduce, just do
Thanks for looking into this.
Try
As per the above comment, the function on In particular, if you want to write polynomials, use int-like numeric types for the powers. What do you think?
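A minimal sketch of that advice, assuming a JAX release that includes the fix discussed in this thread; the function name `poly` is made up for illustration:

```python
import jax
import jax.numpy as jnp

def poly(x):
    # x**0 + x**1 + x**2, with the exponents kept as an integer-typed array
    powers = jnp.arange(3)  # integer dtype by default
    return jnp.sum(x ** powers)

value = poly(0.0)
slope = jax.grad(poly)(0.0)
print(value, slope)
```

Writing the exponents as `jnp.arange(3.0)` or passing `dtype=float` selects the floating-exponent case instead, which is the one reported in this thread as still producing `nan` at 0.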
Hi @mattjj, thanks for this. I'm finding this choice very confusing, since at the back of my head, I'm dealing with a function in one variable, x, and so it's unexpected that for every expression I write down inside such a function, I need to check to make sure what I think are constants are well-behaved if they were to become variables. But perhaps there are good reasons why this should be. I'm still running into trouble along these lines though, I think when
Hey @hongwanliu, sorry I didn't notice your message until now. (Don't hesitate to make new issues; they're more visible!) That seems very weird...
If you change the
b = lambda z: jnp.sum(z**jnp.arange(0, 2, dtype=float))
d = lambda z: 1 + z
print(jax.grad(d)(1.)) # gives 1. as expected
print(jax.grad(b)(1.)) # gives 1. also
I'd say that's pretty surprising though. We probably need to revise this; either there's some bug, or this dtype-based-resolving-of-ambiguities is too subtle and we should raise an error instead (asking the user to be explicit).
I opened #17995, let's track there.
Description
jax.grad does not handle a constant in a polynomial correctly and results in nan when differentiating at 0.
Here is an example where differentiating x^2 + x + 1 at 0 results in nan
The issue is that the derivative subtracts 1 from all exponents and results in computing the expression 0 * 1/x
To illustrate this, here is the resulting jax expression:
I also noticed that this bug does not exist in an earlier version of jax (I checked jax 0.2.10 w/ jaxlib 0.1.62).
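For reference, the failing term can be written out explicitly. Applying the general power rule to each monomial, with the constant treated as x^0 (an assumption about how the polynomial was expressed), gives:

```latex
\frac{d}{dx}\left(x^{2} + x^{1} + x^{0}\right)
  = 2\,x^{1} + 1 \cdot x^{0} + 0 \cdot x^{-1}
```

At x = 0 the last term becomes 0 * (1/0) = 0 * inf, which IEEE 754 floating-point arithmetic defines as nan. The nan then propagates through the sum, even though the true derivative at 0 is 1.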
What jax/jaxlib version are you using?
jax 0.4.3, jaxlib 0.4.3
Which accelerator(s) are you using?
CPU
Additional system info
No response
NVIDIA GPU info
No response