Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix pow_p jvp rule at x=0. y=0 #16419

Merged
merged 2 commits into from
Aug 23, 2023
Merged

fix pow_p jvp rule at x=0. y=0 #16419

merged 2 commits into from
Aug 23, 2023

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Jun 15, 2023

fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes) we need to distinguish between pow :: inexact -> int -> inexact (which is differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which isn't); see #14397 (comment).

Instead of making a new primitive, we made the old one polymorphic and made it switch its behavior on the element type of its second argument.

@mattjj mattjj requested a review from froystig June 15, 2023 01:38
fixes jax-ml#14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see jax-ml#14397 (comment).

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Aug 22, 2023
@copybara-service copybara-service bot merged commit af42359 into jax-ml:main Aug 23, 2023
7 checks passed
@mattjj mattjj deleted the pow-jvp branch August 23, 2023 00:22
mattjj added a commit to mattjj/jax that referenced this pull request Oct 7, 2023
alhridoy pushed a commit to alhridoy/jax that referenced this pull request Oct 20, 2023
alhridoy pushed a commit to alhridoy/jax that referenced this pull request Oct 20, 2023
fixes jax-ml#17995

Add precision warning and workaround to jnp.arange documentation

Remove mistakenly added file and update jnp.arange documentation

Update jnp.arange documentation based on feedback

CLA Issue

Update arange documentation formatting as per review feedback

Update arange documentation formatting as per review feedback

Update lax_numpy.py merge conflict resolved

fix pow jvp rule with int exponent (broken since jax-ml#16419)

fixes jax-ml#17995
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

jax.grad computes incorrect derivative for polynomials
3 participants