With the new CPU runtime, I experience a performance degradation of `jnp.fft.fftn`. Below is a reproducer with a switch that selects the old CPU runtime by setting `--xla_cpu_use_thunk_runtime=false`:
```python
import os
import timeit

# Switch between the new and old CPU runtime.
old_cpu_runtime = True
if old_cpu_runtime:
    # XLA_FLAGS must be set before jax is imported.
    XLA_flag = "--xla_cpu_use_thunk_runtime=false "
    print(f"set: {XLA_flag}")
    os.environ["XLA_FLAGS"] = XLA_flag

import jax
import jax.numpy as jnp


def timeit_like(stmt, number=100, globals=None):
    execution_time = timeit.timeit(stmt, number=number, globals=globals)
    print(f"{stmt} executed in: {execution_time:.5f} seconds (for {number} runs)")


def fft(x):
    return jnp.fft.fftn(x)


fft_jit = jax.jit(fft)

shape = (200, 200, 200)

# Warm-up: compile fft_jit and wait for the first run to finish before timing.
inp = jnp.full(shape, 1.1 + 1.1j)
jax.block_until_ready(fft_jit(inp))

# Benchmarks
timeit_like("jax.block_until_ready(fft(inp))", globals=globals())
timeit_like("jax.block_until_ready(fft_jit(inp))", globals=globals())
```
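Because the flag has to be in place before JAX initializes its CPU backend (which is why the reproducer sets it before `import jax`), both runtimes cannot be compared within one Python process. The following is a minimal sketch, not part of the original report, that runs the same jitted benchmark once per runtime in separate child processes. The names `BENCH`, `xla_flags_env` and `compare_runtimes` are hypothetical, and unlike the reproducer it keeps any `XLA_FLAGS` already present in the environment instead of overwriting them.

```python
import os
import subprocess
import sys

# Benchmark executed in each child process (mirrors the jitted part of the reproducer).
BENCH = """
import timeit
import jax
import jax.numpy as jnp

fft_jit = jax.jit(jnp.fft.fftn)
inp = jnp.full((200, 200, 200), 1.1 + 1.1j)
jax.block_until_ready(fft_jit(inp))  # warm-up / compilation
t = timeit.timeit(lambda: jax.block_until_ready(fft_jit(inp)), number=100)
print(f"{t:.5f} s for 100 runs")
"""


def xla_flags_env(old_cpu_runtime, base_env=None):
    """Return a copy of the environment with XLA_FLAGS set for the chosen runtime.

    Other XLA flags are preserved; only the thunk-runtime flag is added or removed.
    """
    env = dict(os.environ if base_env is None else base_env)
    flags = env.get("XLA_FLAGS", "").split()
    # Drop any existing thunk-runtime setting so it is not specified twice.
    flags = [f for f in flags if not f.startswith("--xla_cpu_use_thunk_runtime")]
    if old_cpu_runtime:
        flags.append("--xla_cpu_use_thunk_runtime=false")
    if flags:
        env["XLA_FLAGS"] = " ".join(flags)
    else:
        env.pop("XLA_FLAGS", None)
    return env


def compare_runtimes():
    """Run BENCH once with the old runtime and once with the default (thunk) runtime."""
    for old in (True, False):
        label = "old runtime" if old else "new (thunk) runtime"
        result = subprocess.run(
            [sys.executable, "-c", BENCH],
            env=xla_flags_env(old),
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            print(f"{label}: failed\n{result.stderr}")
        else:
            print(f"{label}: {result.stdout.strip()}")
```

Calling `compare_runtimes()` prints one timing line per runtime. For the new runtime the sketch simply leaves the flag unset, matching the reproducer, which relies on the thunk runtime being the default.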
For the old CPU runtime (`old_cpu_runtime = True`) I obtain:

With the new CPU runtime (`old_cpu_runtime = False`) I get:

Thus, the new CPU runtime degrades the performance of `fftn` by approximately a factor of 3.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.1
python: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:24:40) [GCC 13.3.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='tux', release='6.12.8-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Thu, 02 Jan 2025 22:52:26 +0000', machine='x86_64')

I can't reproduce this performance degradation on Colab, but I see a slowdown of about 4x with the thunk runtime on my MacBook. Interestingly, this factor stays about the same for larger problem sizes. The HLO is simple (just a single FFT op), so I'm not sure what is causing this.

Thanks for the quick response! Interesting that this is system-dependent. On two Linux laptops, one with an AMD and one with an Intel CPU, I consistently see the slowdown with the new thunk runtime.

Could this issue be escalated, please? This steep performance degradation affects almost all models written with NIFTy, and by now this bug is the only reason NIFTy declares a conflict with newer JAX versions.
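The maintainer's remark that the HLO is just a single FFT op can be checked locally with JAX's ahead-of-time lowering API. This is a minimal sketch, assuming a small `(8, 8, 8)` input is enough to show the op structure; the shape is chosen only to keep the printed output short and is not taken from the report.

```python
import jax
import jax.numpy as jnp


def fft(x):
    return jnp.fft.fftn(x)


# Small complex input, used only to inspect the generated program.
inp = jnp.full((8, 8, 8), 1.1 + 1.1j)

lowered = jax.jit(fft).lower(inp)

# Program emitted by JAX, before XLA optimizations.
print(lowered.as_text())

# HLO after XLA's CPU compilation pipeline.
compiled = lowered.compile()
print(compiled.as_text())
```

Running this under both settings of `--xla_cpu_use_thunk_runtime` and comparing the compiled text may help show whether the two runtimes compile the FFT differently or only execute it differently.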