jax.grad computes incorrect derivative for polynomials #14397

Closed
wuxishy opened this issue Feb 10, 2023 · 12 comments · Fixed by #16419


wuxishy commented Feb 10, 2023

Description

jax.grad does not handle the constant term of a polynomial correctly and returns nan when differentiating at 0.
Here is an example where differentiating x^2 + x + 1 at 0 results in nan:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**jnp.arange(3))

jax.grad(f)(0.0)
# output is Array(nan, dtype=float32, weak_type=True)

The issue is that the derivative rule subtracts 1 from every exponent, so the constant term ends up computing 0 * x**(-1), which evaluates to 0 * inf = nan at x = 0.
To illustrate, here is the resulting jaxpr:

>>> jax.make_jaxpr(jax.grad(f))(0.0)
{ lambda ; a:f32[]. let
    b:i32[3] = iota[dimension=0 dtype=int32 shape=(3,)] 
    c:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] b
    d:f32[3] = pow a c
    e:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    f:f32[3] = sub c e
    g:f32[3] = pow a f
    h:f32[3] = mul c g
    i:f32[3] = convert_element_type[new_dtype=float32 weak_type=False] d
    _:f32[] = reduce_sum[axes=(0,)] i
    j:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    k:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] j
    l:f32[3] = mul k h
    m:f32[] = reduce_sum[axes=(0,)] l
  in (m,) }
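
To see the failing term in isolation, here is a minimal sketch (not from the original report) of what that derivative rule computes for each exponent:

import jax.numpy as jnp

x = jnp.array(0.0)
c = jnp.arange(3, dtype=jnp.float32)  # exponents 0., 1., 2.
# The power rule evaluates c * x**(c - 1) for every exponent; for c = 0
# this is 0 * 0**(-1) = 0 * inf = nan.
print(c * x**(c - 1))  # [nan  1.  0.]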

I also noticed that this bug does not exist in earlier versions of jax (I checked jax 0.2.10 with jaxlib 0.1.62).

What jax/jaxlib version are you using?

jax 0.4.3, jaxlib 0.4.3

Which accelerator(s) are you using?

CPU

Additional system info

No response

NVIDIA GPU info

No response

@wuxishy wuxishy added the bug Something isn't working label Feb 10, 2023

mattjj commented Feb 10, 2023

Thanks for raising this. It looks like we can also repro with just jax.grad(lambda x: x ** 0.)(0.).

@mattjj mattjj self-assigned this Feb 10, 2023

mattjj commented Feb 10, 2023

Ah, now I remember... see #12033. We used to have select statements to avoid this, but that led to an incorrect second derivative with respect to the second input. See in particular this comment.

I'm not sure how to fix this yet!
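
For background on why the where/select guards mentioned above are delicate under autodiff, here is the classic pitfall as a short sketch (this is the general gotcha, not the exact second-derivative failure referenced above):

import jax
import jax.numpy as jnp

# The untaken sqrt branch is still differentiated; its infinite derivative
# at 0 is multiplied by a zero cotangent, and 0 * inf = nan.
print(jax.grad(lambda x: jnp.where(x > 0., jnp.sqrt(x), 0.))(0.0))  # nan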


mattjj commented Feb 10, 2023

Okay, thanks to @dougalm for figuring much of this out and @hawkinsp for explaining some of the current code.

My current understanding is that inside jnp.power, aka **, there are at least three functions:

  1. a float -> int -> float partial (i.e. sometimes-nan-producing) function representing the real-valued mathematical function $(x, n) \mapsto x^n$ on real bases and integer exponents, excluding points of the form $(0, -n)$ with $n \in \mathbb{N}_+$; for each fixed integer exponent it is continuous and infinitely differentiable on the domain of its first argument (the discontinuities lie outside its domain);
  2. a float -> float -> float partial function representing the real-valued mathematical function $(x, y) \mapsto x^y$ on positive real bases and real exponents (not continuous at $(0.0, 0.0)$: consider approaching it along $0^y = 0$ versus $x^0 = 1$), which is continuous and infinitely differentiable in its domain;
  3. a complex -> complex -> complex partial function representing some principal branch of the multi-valued mathematical function $(a, b) \mapsto e^{b \ln a}$ (also not continuous at $(0.0, 0.0)$).

This bug is arising because we are accidentally conflating these functions. In particular, we should differentiate function (1) differently from function (2) or (3), since the former is differentiable at $(0.0, 0)$ while the latter two are not differentiable at $(0.0, 0.0)$. I suspect the same autodiff rule can be used for (2) and (3).

(In addition to autodiff, it may make sense to disentangle these for performance reasons.)

To fix this bug, we ultimately need to have different autodiff behavior depending on which function we're working with, which we can infer from the types of the arguments. That's my plan, though I'm not sure yet how to organize the code (i.e. whether to make one primitive which handles differently-typed inputs, or instead just to have the first function be non-primitive, since there's no XLA HLO op for it anyway).
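
As a rough sketch of that type-dispatch idea (illustrative only, not the actual implementation that landed), an integer-exponent power could be given its own differentiation rule via jax.custom_jvp:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def int_pow(x, n):
    # n is an integer-dtype exponent; the primal is the ordinary power.
    return x ** n

@int_pow.defjvp
def int_pow_jvp(primals, tangents):
    x, n = primals
    dx, _ = tangents  # integer exponents carry no useful tangent
    y = int_pow(x, n)
    # d/dx x**n = n * x**(n - 1); guard n == 0 so x**(-1) is never formed.
    dydx = jnp.where(n == 0, 0.0, n * x ** jnp.where(n == 0, 0, n - 1))
    return y, dydx * dx

print(jax.grad(lambda x: jnp.sum(int_pow(x, jnp.arange(3))))(0.0))  # 1.0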

@mattjj mattjj added the P1 (soon) Assignee is working on this now, among other tasks. (Assignee required) label Feb 11, 2023
mattjj added a commit to mattjj/jax that referenced this issue Jun 15, 2023
fixes jax-ml#14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see jax-ml#14397 (comment).

Instead of making a new primitive, we made the old one polymorphic and switched
its behavior on the element type of its second argument.

Co-authored-by: Roy Frostig <frostig@google.com>
mattjj added further commits to mattjj/jax referencing this issue (Jul 10 to Jul 29, 2023) with the same message, later adding:

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.
@hongwanliu

Does anyone have a temporary workaround for this?


mattjj commented Aug 22, 2023

There's a workaround mentioned in #14745 (comment). Though I'm kind of embarrassed this issue is still open... thanks for the ping.
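
One workaround pattern along these lines (a sketch; the pattern in the linked comment may differ) is to mask the zero exponent with where, so the derivative never forms x**(-1):

import jax
import jax.numpy as jnp

def safe_pow(x, p):
    # For p == 0 return the constant 1 directly, and feed the power a
    # harmless exponent so its derivative never evaluates x**(-1).
    return jnp.where(p == 0, 1.0, x ** jnp.where(p == 0, 1, p))

f = lambda x: jnp.sum(safe_pow(x, jnp.arange(3)))
print(jax.grad(f)(0.0))  # 1.0, the derivative of x**2 + x + 1 at 0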


mattjj commented Aug 23, 2023

This issue should be fixed at GitHub HEAD! We haven't updated the PyPI version yet, but if you install from GitHub the issue should be gone.


hongwanliu commented Sep 2, 2023

@mattjj Finally got around to checking this, and I'm finding that this bug is not fixed in v0.4.16.dev20230901. To reproduce, just do

>>> jax.grad(lambda x: x ** 0.)(0.)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Array(nan, dtype=float32, weak_type=True)

Thanks for looking into this.


mattjj commented Sep 2, 2023

Try jax.grad(lambda x: x ** 0)(0.) instead:

In [1]: import jax
In [2]: jax.grad(lambda x: x ** 0)(0.)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Out[2]: Array(0., dtype=float32, weak_type=True)

As per the above comment, the function on $\mathbb{R} \times \mathbb{R}$ isn't differentiable at $(0, 0)$, but the function on $\mathbb{R} \times \mathbb{Z}$ is (with respect to its first argument). We let the ** infix operator refer to either function, disambiguated by the types you give it: if you give it a pair of float-like inputs then we treat it as the former, and if you give it a float-like paired with an int-like we treat it as the latter.

In particular, if you want to write polynomials, use int-like numeric types for the powers.

What do you think?
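
For polynomials specifically, another option that sidesteps the exponent-dtype question entirely (a sketch, not something suggested in the thread) is jnp.polyval, which evaluates via Horner's rule and is differentiable everywhere:

import jax
import jax.numpy as jnp

# Coefficients in decreasing order of power: x**2 + x + 1.
coeffs = jnp.array([1.0, 1.0, 1.0])
f = lambda x: jnp.polyval(coeffs, x)
print(jax.grad(f)(0.0))  # 1.0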

@hongwanliu

Hi @mattjj, thanks for this. I'm finding this choice very confusing: in the back of my mind I'm dealing with a function of one variable, x, so it's unexpected that for every expression I write inside such a function, I need to check that what I think are constants would be well-behaved if they were to become variables. But perhaps there are good reasons for this.

I'm still running into trouble along these lines though, I think when arange is used to build the polynomial. Here is a simple example that breaks:

import jax
import jax.numpy as jnp

b = lambda z: jnp.sum(z**jnp.arange(0, 2))
d = lambda z: 1. + z
print(b(2.543), d(2.543)) # check they are the same and I'm not crazy
print(jax.grad(d)(1.))  # gives 1. as expected
print(jax.grad(b)(1.)) # gives 2.


mattjj commented Oct 7, 2023

Hey @hongwanliu, sorry I didn't notice your message until now. (Don't hesitate to make new issues; they're more visible!)

That seems very weird...


mattjj commented Oct 7, 2023

If you change the jnp.arange to use float dtype, then things agree:

b = lambda z: jnp.sum(z**jnp.arange(0, 2, dtype=float))
d = lambda z: 1 + z

print(jax.grad(d)(1.))  # gives 1. as expected
print(jax.grad(b)(1.)) # gives 1. also

I'd say that's pretty surprising though. We probably need to revise this; either there's some bug, or this dtype-based-resolving-of-ambiguities is too subtle and we should raise an error instead (asking the user to be explicit).


mattjj commented Oct 7, 2023

I opened #17995; let's track it there.
