Performance regression with JAX 0.4.32 on CPU #3
Can you dump HLO somewhere? One known problem is that small while loops got a lot more expensive, because we basically replaced a compiler with an interpreter; see a workaround I submitted for the CPU backend: jax-ml/jax@15d4389. I'll try to reproduce it on my side, but with an HLO dump I can do it a lot more easily.
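To illustrate the kind of computation the workaround targets, here is a minimal sketch (not from the thread; the function and iteration count are made up): a jitted while loop with many cheap iterations, where per-iteration dispatch overhead would dominate the trivial body.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_up(n):
    # A small while loop with many cheap iterations: exactly the shape of
    # computation where per-iteration runtime overhead dominates.
    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, acc = state
        return i + 1, acc + 1.0

    return jax.lax.while_loop(cond, body, (0, 0.0))

count_up(10_000)[1].block_until_ready()
```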
You can dump HLO by setting XLA_FLAGS=--xla_dump_to=&lt;directory&gt; in the environment.
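A minimal sketch of how this looks in a JAX program (the dump path is an arbitrary example):

```python
import os
# XLA_FLAGS must be set before importing jax so the backend picks it up.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/hlo_dump"

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x * x)

f(jnp.arange(8.0)).block_until_ready()
# Compiling f writes the HLO modules (before and after optimization)
# as text files under /tmp/hlo_dump.
```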
Thanks for the quick feedback!
Running with the default settings produces hlo_default.zip; running with the new runtime disabled produces hlo_no_thunk.zip. I'll try to look into it as well. Edit: removed.
Edited to add: This comment was meant for a different issue, not this one.
Very interesting! Thanks a lot for looking into it. I am not (?) using tanh in the code, as far as I know. Is tanh generated by the backend, maybe by fusing some operations?
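One way to answer that yourself (a sketch, not from the thread; f and x stand in for the actual model and inputs) is to search the post-optimization HLO of the jitted function:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Stand-in for the real computation.
    return jnp.log1p(jnp.exp(x))

x = jnp.ones((4,))
# .lower(...).compile().as_text() returns the optimized HLO, i.e. what the
# backend actually runs, including any ops introduced by rewrites or fusion.
optimized_hlo = jax.jit(f).lower(x).compile().as_text()
print("tanh" in optimized_hlo)
```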
Oops, sorry about that! That comment was meant for a different issue (jax-ml/jax#23590); I posted on the wrong tab. Thank you for the HLO dump! I see one
Ok, thanks for the clarification! Still, I was a bit surprised, since I am not using
jax(lib) 0.4.33 just dropped. Performance is slightly worse than with 0.4.32, while the old CPU backend performs the same as before.
@penpornk would you prefer I post this as a JAX issue?
Yes, there were issues with 0.4.32, so the wheel was pulled off PyPI and 0.4.33 was a quick re-release.
Up to you. For the XLA:CPU team, the bug being here doesn't make a big difference; we have created an internal bug within our team and @pparuzel is looking at this issue. For future bugs it would be better to post on https://github.com/google/jax or https://github.com/openxla/xla for more visibility. (It could also help the JAX folks keep track of things they want to include in a new release.)
Looks like 92.7% of the time is spent in a single function, which should narrow things down considerably. I need to keep digging to find the exact bottleneck.
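The percentage above presumably comes from profiling the runtime directly; to gather a comparable picture from the Python side, one can capture a JAX profiler trace (a sketch; the workload and paths are placeholders):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((512, 512))
step(x).block_until_ready()  # warm up so compilation stays out of the trace

with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(100):
        x = step(x)
    x.block_until_ready()
# View with TensorBoard (needs the profiler plugin):
#   tensorboard --logdir=/tmp/jax-trace
```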
The performance impact seems to be somewhat inconsistent. Comparing the ratio of new-backend to old-backend runtime across machines: runtime on the old i7-8550U is generally higher, unsurprisingly, but the relative difference between the old and new runtimes is much smaller on the newer i5-1345U. Could this just be an effect of cache sizes, etc.?
We are noticing that the new runtime generates fewer SIMD instructions for this particular case, so cache size might indeed explain the difference between these architectures. However, the root cause of the slowdown is still to be found. The current suspicion is that the codegen is missing some crucial LLVM metadata, which discourages SIMD optimizations in the thunk runtime.
There is a performance regression in BP on CPU from JAX 0.4.31 to 0.4.32. The reason seems to be the new CPU backend with increased concurrency (jax-ml/jax#23590).
With the default behavior in JAX 0.4.32, BP takes 679 ms; with CPU concurrency manually deactivated, it takes 445 ms.
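For reference, a minimal sketch of the opt-out referred to above, assuming the --xla_cpu_use_thunk_runtime flag from jax-ml/jax#23590:

```python
import os
# Must be set before importing jax; disables the new thunk-based CPU
# runtime and falls back to the previous backend behavior.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax  # subsequent compilations use the old CPU runtime
```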