Replies: 3 comments 5 replies
-
CPU is probably not fusing it, but on TPU the optimized HLO does show the fusion.
You should look at the optimized HLO for fusions! On CPU I can't see the fusion, but on TPU you can.
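For anyone following along: the optimized, backend-specific HLO is what `.compile().as_text()` returns, as opposed to the pre-compilation StableHLO. A minimal sketch; the function `f` here is a stand-in, not the original code:

```python
import jax
import jax.numpy as jnp

# Stand-in function for illustration only.
def f(x):
    return jnp.sum(jnp.cos(x))

x = jnp.ones(1000)

# Optimized HLO after XLA's backend passes; fusion ops show up here.
# The output differs by backend (CPU vs TPU), since fusion decisions
# are made per backend during compilation.
print(jax.jit(f).lower(x).compile().as_text())
```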
-
Interestingly, I think I get fusion on the CPU if my parameter array is <= 32 elements: the compiled HLO contains a fusion op. But with 33 or more elements, the fusion op disappears.
This doesn't make much sense to me.
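A quick way to check this threshold; a sketch reusing the hypothetical sum-of-cosines reduction, where the substring test is just a crude heuristic for spotting fusion ops in the HLO text:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.cos(x))

for n in (32, 33):
    hlo = jax.jit(f).lower(jnp.ones(n)).compile().as_text()
    # Crude check: look for a fusion op in the optimized HLO text.
    print(n, "fusion" in hlo)
```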
-
One more bit of info: the jaxpr and the StableHLO (`jit(f).lower(x).as_text()`) are identical for the 32-element and 33-element arrays, aside from the tensor dimension. The only difference appears in the compiled version (`jit(f).lower(x).compile().as_text()`), with one containing the fusion operator and the other not. Does this indicate a potential problem with XLA as opposed to JAX?
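For context, these are the three artifacts being compared. The jaxpr and StableHLO are produced by JAX's tracing and lowering, while the compiled text reflects XLA's backend-specific optimization passes, which is where fusion decisions are made. A sketch, again with a stand-in function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.cos(x))

x = jnp.ones(33)

print(jax.make_jaxpr(f)(x))                     # jaxpr: JAX trace-level IR
print(jax.jit(f).lower(x).as_text())            # StableHLO: backend-independent
print(jax.jit(f).lower(x).compile().as_text())  # optimized HLO: backend-specific
```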
-
Hello! What is the most efficient way to do reductions so that data is passed over only once?
The following code:
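(The original snippet isn't preserved in this copy; the following is a minimal stand-in matching the description below, namely an elementwise cosine followed by a sum reduction.)

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Elementwise transform followed by a reduction.
    return jnp.sum(jnp.cos(x))

x = jnp.arange(1000, dtype=jnp.float32)
print(f.lower(x).as_text())  # StableHLO for the computation
```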
produces HLO in which an elementwise cosine op fills a full array that then feeds a separate stablehlo.reduce.
This implies to me that the data is iterated over once to compute an array of cosine values and then again to do the reduction, which is inefficient. Is that two-pass understanding correct? If so, what is the way to do a reduction plus a computation in one pass?
Incidentally, I noticed that a version I tried with jax.lax.reduce seems to have a "reducer" block instead of the stablehlo.reduce line. Is there any difference between these two as far as the number of passes over the data?
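For reference, a sketch of how that jax.lax.reduce variant might look, assuming the same stand-in computation: `lax.reduce` takes the operand, an initial value, a binary computation, and the dimensions to reduce over.

```python
import jax
import jax.numpy as jnp
from jax import lax

def g(x):
    # Same computation, expressed with lax.reduce instead of jnp.sum.
    return lax.reduce(jnp.cos(x), 0.0, lax.add, dimensions=(0,))

x = jnp.arange(1000, dtype=jnp.float32)
print(jax.jit(g).lower(x).as_text())
```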
I'm using the CPU backend if that's relevant.
Thanks!