Performance regression with JAX 0.4.32 on CPU #3

Open
clemisch opened this issue Sep 12, 2024 · 13 comments

@clemisch
Owner

clemisch commented Sep 12, 2024

There is a performance regression in the back projection (BP) on CPU from JAX 0.4.31 to 0.4.32. The cause seems to be the new CPU backend with increased concurrency (jax-ml/jax#23590).

Default behavior in JAX 0.4.32:

$ python3 timing.py --bp --fp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : True
bp       : True
size     : 128
dtype    : 'float32'
==== FP ====
(128, 128, 128) -> (128, 128, 128) :  1154 ms ,  0.55 µs per pixel , 0.002 GRays/s
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   679 ms ,  0.32 µs per voxel , 0.003 GRays/s
                                       ^^^

vs. manually deactivated CPU concurrency:

$ XLA_FLAGS=--xla_cpu_use_thunk_runtime=false python3 timing.py --bp --fp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : True
bp       : True
size     : 128
dtype    : 'float32'
==== FP ====
(128, 128, 128) -> (128, 128, 128) :  1231 ms ,  0.59 µs per pixel , 0.002 GRays/s
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   445 ms ,  0.21 µs per voxel , 0.005 GRays/s
                                       ^^^

BP takes 679 ms with the new runtime vs. 445 ms with the old one.
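
(Side note for reproducing the comparison from Python rather than via the shell: XLA reads XLA_FLAGS when the CPU backend is initialized, so the flag has to be set before the first computation. A minimal sketch, with a toy function standing in for the actual timing.py code:)

    import os

    # Must be set before JAX initializes its CPU backend,
    # i.e. before the first computation (safest: before importing jax).
    os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

    import jax
    import jax.numpy as jnp

    @jax.jit
    def toy(x):  # placeholder kernel, not the real back projector
        return jnp.sin(x).sum()

    x = jnp.ones((128, 128, 128), dtype=jnp.float32)
    print(toy(x).block_until_ready())  # runs on the old (non-thunk) CPU runtime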

@ezhulenev

Can you dump the HLO somewhere? One known problem is that small while loops got a lot more expensive, because we basically replaced a compiler with an interpreter; see the workaround I submitted for the CPU backend: jax-ml/jax@15d4389

Linux perf should also show where the CPU cycles are spent.

I'll try to reproduce it on my side, but with an HLO dump it will be a lot easier.

@penpornk

penpornk commented Sep 13, 2024

You can dump the HLO by setting XLA_FLAGS=--xla_dump_to=/tmp/hlo (if you are using more than one flag, just add a space between them, e.g., XLA_FLAGS="--xla_dump_to=/tmp/hlo --xla_cpu_use_thunk_runtime=false"). Could you please zip all files in the dump folder and upload them here?

The while loop is one of the suspects. Another is oneDNN custom calls: they are not available in the new runtime yet, so if your code has a lot of matmuls/convolutions, you may see some slowdowns.
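
(For a quick check without a full dump, the optimized HLO of a single jitted function can also be printed from JAX itself and scanned for while loops or custom calls. A minimal sketch with a toy function, not the actual projector code:)

    import jax
    import jax.numpy as jnp

    @jax.jit
    def toy(x):  # placeholder, stands in for the real jitted function
        return jnp.tanh(x @ x).sum()

    x = jnp.ones((256, 256), dtype=jnp.float32)
    compiled = toy.lower(x).compile()  # post-optimization HLO for this backend
    print(compiled.as_text())          # look for while loops / custom-call ops here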

@clemisch
Owner Author

clemisch commented Sep 13, 2024

Thanks for the quick feedback!

$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_default" python3 timing.py --bp --size=128

produces hlo_default.zip.

$ XLA_FLAGS="--xla_dump_to=/scratch/hlo_no_thunk --xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128

produces hlo_no_thunk.zip.

I'll try to look into perf in the meantime.


Edit: removed --fp to unclutter the HLO.

@penpornk

penpornk commented Sep 13, 2024

Looks like this is because of my f64 tanh approximation commit:
openxla/xla@ae96f6e

I'll either fix it or temporarily disable it before the JAX 0.4.32 re-release.

Edited to add: This comment was meant for a different issue, not this one.

@clemisch
Owner Author

clemisch commented Sep 14, 2024

Very interesting! Thanks a lot for looking into it.

I am not (?) using tanh in the code, afaik. Is tanh generated by the backend, maybe fusing some operations?

@penpornk

I am not (?) using tanh in the code, afaik. Is tanh generated by the backend, maybe fusing some operations?

Oops. Sorry about that! That comment was meant for a different issue (jax-ml/jax#23590). I posted on the wrong tab.

Thank you for the HLO dump! I see one while loop in the main module (and no oneDNN custom calls in the no-thunk dump), so this regression may indeed be from while thunk. Our team will investigate soon. :)

@clemisch
Owner Author

Ok, thanks for the clarification!

Still, I was a bit surprised: I am not using while loops in the code behind timing.py, but scans. I was not aware that scan and while are lowered to the same HLO and was (unconsciously) assuming that scan is more efficient. Now I see that the docs indeed state:

scan is a JAX primitive and is lowered to a single WhileOp
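
(A quick way to see this for yourself, with a toy scan rather than the actual projector: the lowered IR of a scanned function contains a single while op.)

    import jax
    import jax.numpy as jnp

    def step(carry, x):
        carry = carry + x
        return carry, carry               # (new carry, per-step output)

    @jax.jit
    def cumsum_scan(xs):
        _, ys = jax.lax.scan(step, jnp.float32(0.0), xs)
        return ys

    xs = jnp.arange(16, dtype=jnp.float32)
    ir = cumsum_scan.lower(xs).as_text()  # pre-optimization StableHLO
    print("while" in ir)                  # True: the scan became a single while loop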

@clemisch
Owner Author

clemisch commented Sep 17, 2024

jax(lib) 0.4.33 just dropped.

Performance is slightly worse than with 0.4.32:

$ python3 timing.py --bp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : False
bp       : True
size     : 128
dtype    : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   703 ms ,  0.34 µs per voxel , 0.003 GRays/s
                                       ^^^

The old CPU backend is the same as before:

$ XLA_FLAGS="--xla_cpu_use_thunk_runtime=false" python3 timing.py --bp --size=128
gpu      : None
prealloc : False
pmap     : False
fp       : False
bp       : True
size     : 128
dtype    : 'float32'
==== BP ====
(128, 128, 128) -> (128, 128, 128) :   442 ms ,  0.21 µs per voxel , 0.005 GRays/s
                                       ^^^

@clemisch
Owner Author

@penpornk would you prefer I post this as a jax issue?

@penpornk

jax(lib) 0.4.33 just dropped.

Yes, there were issues with 0.4.32, so the wheel was pulled off PyPI and 0.4.33 was a quick re-release.

would you prefer I post this as a jax issue?

Up to you. For the XLA:CPU team, it doesn't make a big difference where the bug is filed. We have created an internal bug within our team, and @pparuzel is looking at this issue.

For future bugs, it would be better to post on https://github.com/google/jax or https://github.com/openxla/xla, for more visibility. (And it could help JAX folks keep track of things they want to include in a new release.)

@pparuzel

Looks like 92.7% of the time is spent in fusion.clone. The WhileOp is a mere 3.6% of the total.

That would probably narrow it down to:

  ROOT fusion.clone = f32[128,128,128]{2,1,0} fusion(p.1, p.2, p.3, p.4, p.5, /*index=5*/p.6, p.7, p.8), kind=kLoop, calls=fused_computation.clone, metadata={op_name="jit(get_bp)/jit(main)/while/body/add" source_file="/home/clem/git/jaxtomo/jaxtomo/projectors/cone_bp.py" source_line=71}, backend_config={"outer_dimension_partitions":["8"]}

I need to keep on digging to find the exact bottleneck.
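
(For reference, a similar per-op view can be obtained without Linux perf via JAX's built-in profiler; the trace opens in TensorBoard/Perfetto. A minimal sketch with a toy kernel, not the actual back projector:)

    import jax
    import jax.numpy as jnp

    @jax.jit
    def toy(x):  # placeholder kernel
        return jnp.sin(x @ x).sum()

    x = jnp.ones((512, 512), dtype=jnp.float32)
    toy(x).block_until_ready()  # warm up so compilation stays out of the trace

    with jax.profiler.trace("/tmp/jax-trace"):  # writes a profile to this directory
        toy(x).block_until_ready()

    # Inspect with: tensorboard --logdir /tmp/jax-trace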

@clemisch
Owner Author

clemisch commented Oct 9, 2024

The performance impact seems to be somewhat inconsistent.

This is a table of the ratio time_default / time_nothunk, i.e. new/old runtime. The cell (Intel i7-8550U, --size=128) is the ratio from my previous benchmark (703 ms / 442 ms).

--size=   Intel i7-8550U      Intel i5-1345U
128       1.59 (703/442)      1.21 (461/380)
256       1.60 (10625/6624)   1.07 (6479/6070)

Runtime on the old i7-8550U is generally higher, unsurprisingly. But the relative difference between the old and new runtime is much smaller on the newer i5-1345U.

Could this be an effect only of cache sizes etc.?
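
(For context on how these ratios were measured: the key point is block_until_ready, since JAX dispatches asynchronously. A minimal harness in the spirit of timing.py, with a toy kernel standing in for the back projector:)

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def toy_bp(x):  # placeholder, not the real back projector
        return jnp.sin(x).sum(axis=0)

    x = jnp.ones((128, 128, 128), dtype=jnp.float32)
    toy_bp(x).block_until_ready()        # warm-up: exclude compile time

    n = 10
    t0 = time.perf_counter()
    for _ in range(n):
        toy_bp(x).block_until_ready()    # wait for the async dispatch to finish
    dt = (time.perf_counter() - t0) / n
    print(f"{dt * 1e3:.1f} ms per call")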

@pparuzel

pparuzel commented Oct 9, 2024

We are noticing that the new runtime generates fewer SIMD instructions for this particular case. Therefore, cache sizes might indeed explain the difference between these architectures. However, the root cause of the slowdown is still to be found.

Currently, the suspicion is that codegen in the thunk runtime is missing some crucial LLVM metadata, the absence of which discourages SIMD optimizations.
