Array constants are generated on the host during AD #7093
Comments
This was intended at one point, pre-lazy-sublanguage (#1668) and pre-omnistaging (#3370), since there is special logic for handling broadcasted numpy constants efficiently in XLA lowering. (Hence why it's written with a broadcast of a scalar constant.) I think we don't need this logic anymore. There might be edge cases where it's more efficient in eager mode, but I can't think of one, and they're probably not that interesting. I agree we should replace it.
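A minimal sketch of the kind of replacement being discussed, for context. The helper names below are made up, and `lax.full` is used only as an illustration of staging the fill out to the device; per the fix comment below, the actual change used `broadcast_in_dim`.

```python
import numpy as np
from jax import lax

# Roughly the host-side pattern the issue describes: the zeros are produced
# as a concrete NumPy value during tracing.
def zeros_like_host(aval):
  return np.broadcast_to(np.array(0, aval.dtype), aval.shape)

# A staged-out alternative: emit the fill as a lax primitive so the zeros are
# computed where the rest of the traced computation runs, instead of being
# embedded as a host-generated constant.
def zeros_like_staged(aval):
  return lax.full(aval.shape, 0, aval.dtype)
```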
Thanks, I can also do it.
Fixes: jax-ml#7093. Also fixes type checking in jax2tf, because now we have to be careful about computations whose result is float0 (the broadcast_in_dim used to compute the zeros).
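As a rough, hedged illustration of where `float0` results come from (the function `g` below is hypothetical): gradients with respect to integer inputs have the size-zero dtype `float0`, so a staged-out `broadcast_in_dim` that produces those zeros needs the special handling in jax2tf type checking mentioned above.

```python
import jax
import jax.numpy as jnp

def g(x):
  # Cast to float so the function is differentiable, then square.
  return x.astype(jnp.float32) ** 2

# With allow_int=True, the gradient with respect to an integer input is a
# zeros value of dtype float0 rather than an ordinary floating-point array.
dx = jax.grad(g, allow_int=True)(jnp.int32(3))
print(dx.dtype == jax.dtypes.float0)  # True
```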
I noticed that the function `abstract_arrays.zeros_like_shaped_array` has code like `np.broadcast_to(np.array(0, aval.dtype), aval.shape)` (using regular NumPy). These values will be computed on the host during tracing and then may be passed to the device. Is this the desired behavior? For example, when we take gradients of constant functions:
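The code sample from the original report is not reproduced in this copy; the following is a minimal reconstruction, assuming a constant function of a 10x10 input whose gradient's jaxpr is printed (the name `f` is made up):

```python
import numpy as np
import jax

def f(x):
  return 0.0  # constant function: its gradient is an all-zeros 10x10 array

# Printing the jaxpr of the gradient shows how the zeros enter the computation
# (at the time of this issue, as a host-generated array constant).
print(jax.make_jaxpr(jax.grad(f))(np.ones((10, 10), dtype=np.float32)))
```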
we get a jaxpr in which a 10x10 array of 0s is built during tracing and passed to the function as the constant array `a`. Wouldn't it be better to compute the array of 0s directly on the device?
Incidentally, this affects things like jax2tf, where we may end up with large constants embedded in the resulting graph, when instead the computation that produces them could be embedded in the graph.
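A hedged sketch of how one might observe this through jax2tf, reusing the hypothetical `f` from the sketch above; the exact op types seen depend on the jax2tf and TensorFlow versions:

```python
import tensorflow as tf
import jax
from jax.experimental import jax2tf

def f(x):
  return 0.0  # constant function, as in the sketch above

# Convert the gradient function and list the ops in the resulting graph.
# If the zeros were materialized on the host during tracing, they should show
# up as a large Const node; if staged out, as a fill/broadcast-style op.
tf_grad = tf.function(jax2tf.convert(jax.grad(f)), autograph=False)
concrete = tf_grad.get_concrete_function(tf.TensorSpec((10, 10), tf.float32))
print(sorted({op.type for op in concrete.graph.get_operations()}))
```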