Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix convert_element_type on large Python int inputs #6165

Merged
merged 3 commits into from
Mar 22, 2021

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Mar 21, 2021

@jekbradbury noticed that when we call lax.convert_element_type(2 ** 100, jnp.float32), we first convert to int32 when we canonicalize the input dtype. But int32 can't represent 2**100, while a float32 can!

Concretely, before this PR, this fails:

import jax.numpy as jnp
jnp.multiply(2 ** 100, 3.)  # OverflowError: Python int too large to convert to C long

It fails because under the hood it ends up doing this, which also fails:

from jax import lax
lax.convert_element_type(2 ** 100, jnp.float32)

even though this succeeds:

import numpy as np
lax.convert_element_type(np.float32(2 ** 100), jnp.float32)

#6014 caused this to surface in a downstream library, because the special-case logic #6014 removed had effectively called np.array(x, to_dtype) before applying any JAX primitives.

Luckily these issues lead to loud overflow errors from NumPy, rather than a silent loss of bits.

The fix is just to have lax.convert_element_type (i.e. the 'traceable' wrapper) use NumPy to convert Python int inputs to numpy arrays with the target dtype (like a float32), rather than the current behavior of converting to the canonical dtype for the input (like an int32), before transferring the value out of Python and to the device.

I also had to fix up some handling of float0s in host_callback logic. The logic now mirrors the analogous logic in custom_jvp/custom_vjp. I tweaked the implementation of the fix so that these aren't needed anymore. They may still be good changes, but I'd rather keep the PR minimal.

@google-cla google-cla bot added the cla: yes label Mar 21, 2021
@mattjj mattjj changed the title fix convert_element_type on large inputs fix convert_element_type on large Python int inputs Mar 22, 2021
@mattjj mattjj force-pushed the convert-element-type-impl branch 2 times, most recently from 3a11b4c to 78f5c3d Compare March 22, 2021 01:05
@mattjj mattjj added the pull ready Ready for copybara import and testing label Mar 22, 2021
@copybara-service copybara-service bot merged commit 555aba8 into master Mar 22, 2021
@mattjj mattjj deleted the convert-element-type-impl branch March 22, 2021 05:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants