
Array constants are generated on the host during AD #7093

Closed
gnecula opened this issue Jun 24, 2021 · 3 comments · Fixed by #7102

@gnecula (Collaborator) commented Jun 24, 2021

I noticed that the function abstract_arrays.zeros_like_shaped_array has code like np.broadcast_to(np.array(0, aval.dtype), aval.shape) (using regular NumPy). These values will be computed on the host during tracing and then may be passed to the device. Is this the desired behavior?

For example, when we take gradients of constant functions:

import jax
import numpy as np
x = np.ones((10, 10), dtype=np.float32)
jax.make_jaxpr(jax.grad(lambda x: 42.))(x)

we get

{ lambda a ; b.
  let 
  in (a,) }

where a 10x10 array of zeros is built on the host during tracing and passed in as the constant a.
Wouldn't it be better to compute the array of zeros directly on the device?

{ lambda  ; a.
   let  b = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(10, 10) ] 0.0
  in (b,) }
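For concreteness, a device-side construction along those lines might look like the following (a sketch only; zeros_like_on_device is a hypothetical name, not the actual JAX helper):

import numpy as np
from jax import lax

def zeros_like_on_device(aval):
    # Only a scalar zero is created on the host; the broadcast to the full
    # shape is staged out as a broadcast_in_dim and runs on the device.
    return lax.broadcast_in_dim(np.array(0, aval.dtype), aval.shape,
                                broadcast_dimensions=())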

Incidentally, this also affects things like jax2tf, where we may end up with large constants embedded in the resulting graph, when instead we could embed the computation that produces them.

@gnecula gnecula added the bug Something isn't working label Jun 24, 2021
@gnecula gnecula assigned gnecula and mattjj and unassigned gnecula Jun 24, 2021
@gnecula gnecula added enhancement New feature or request and removed bug Something isn't working labels Jun 24, 2021
@gnecula gnecula changed the title from "Array constants are generate on the host during AD for integer functions" to "Array constants are generated on the host during AD for integer functions" Jun 24, 2021
@mattjj (Collaborator) commented Jun 24, 2021

This was intended at one point, pre-lazy-sublanguage #1668 and pre-omnistaging #3370, since there is special logic for handling broadcasted numpy constants efficiently in XLA lowering. (That's why it's written with np.broadcast_to, which uses stride tricks for the broadcasted representation, unlike np.zeros!) But that special handling exists only for XLA lowering, not for jax2tf lowering.
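For reference, a quick NumPy-only illustration of that stride-trick difference (not JAX code, just showing why the broadcasted representation is cheap on the host):

import numpy as np

z_view = np.broadcast_to(np.array(0, np.float32), (10, 10))  # read-only view of one scalar
z_full = np.zeros((10, 10), np.float32)                      # real 10x10 buffer

print(z_view.strides)  # (0, 0): every element aliases the single scalar
print(z_full.strides)  # (40, 4): 400 bytes actually allocated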

I think we don't need this logic anymore. There might be edge cases where it's more efficient in eager mode, but I can't think of one, and they're probably not that interesting.

I agree we should replace it with jnp operations, since that will benefit jax2tf! I can take that on.

@gnecula (Collaborator, Author) commented Jun 25, 2021

Thanks, I can also do it.

@gnecula gnecula changed the title from "Array constants are generated on the host during AD for integer functions" to "Array constants are generated on the host during AD" Jun 25, 2021
@gnecula gnecula self-assigned this Jun 25, 2021
gnecula added commits to gnecula/jax that referenced this issue Jun 25, 2021, with the message:
Fixes: jax-ml#7093
Also fixes type checking in jax2tf, because now we have to be careful
about computations that have result float0 (the broadcast_in_dim used
to compute the zeros).
@gnecula (Collaborator, Author) commented Jun 25, 2021

@mattjj I have started on this in #7102. The actual change to use lax instead of np is very small, but there are interactions with jax2tf (fixed) and with numpy 3.6.13 (not yet fixed). I can continue to work on this, just FYI, so that you don't start afresh.

gnecula added further commits to gnecula/jax that referenced this issue Jul 25–26, 2021, with the same commit message as above.