
Performance degradation of fftn with the new CPU runtime #25808

Open
roth-jakob opened this issue Jan 9, 2025 · 4 comments
Labels
bug Something isn't working
@roth-jakob
Contributor

Description

With the new CPU runtime, I experience a performance degradation of jnp.fft.fftn. Below is a reproducer with a switch to toggle between the old and new CPU runtime via --xla_cpu_use_thunk_runtime=false:

# switch between new and old CPU runtime
old_cpu_runtime = True

if old_cpu_runtime:
    # XLA_FLAGS must be set before jax is imported for the flag to take effect
    import os
    XLA_flag = "--xla_cpu_use_thunk_runtime=false "
    print(f'set: {XLA_flag}')
    os.environ["XLA_FLAGS"] = XLA_flag

import jax
import jax.numpy as jnp
import timeit

def timeit_like(stmt, number=100, globals=None):
    execution_time = timeit.timeit(stmt, number=number, globals=globals)
    print(f"{stmt} executed in: {execution_time:.5f} seconds (for {number} runs)")

def fft(x):
    return jnp.fft.fftn(x)

fft_jit = jax.jit(fft)

shape = (200, 200, 200)

# warmup
inp = jnp.full(shape, 1.1 + 1.1j)
fft_jit(inp)

# benchmarks
timeit_like("jax.block_until_ready(fft(inp))", globals=globals())
timeit_like("jax.block_until_ready(fft_jit(inp))", globals=globals())

For the old CPU runtime (old_cpu_runtime = True) I obtain:

set: --xla_cpu_use_thunk_runtime=false 
jax.block_until_ready(fft(inp)) executed in: 2.37158 seconds (for 100 runs)
jax.block_until_ready(fft_jit(inp)) executed in: 2.21418 seconds (for 100 runs)

With the new CPU runtime (old_cpu_runtime = False) I get:

jax.block_until_ready(fft(inp)) executed in: 6.82972 seconds (for 100 runs)
jax.block_until_ready(fft_jit(inp)) executed in: 6.69958 seconds (for 100 runs)

Thus, the new CPU runtime degrades the performance of fftn by roughly a factor of 3.
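
As a side note, timeit.timeit reports a single total, which can be noisy; here is a small variant of timeit_like (a sketch reusing the definitions from the reproducer above) that reports the best of several repeats:

def timeit_best(stmt, number=100, repeat=5, globals=None):
    # report the fastest of several repeats to damp scheduling noise
    times = timeit.repeat(stmt, number=number, repeat=repeat, globals=globals)
    print(f"{stmt}: best of {repeat} = {min(times):.5f} seconds (for {number} runs)")

timeit_best("jax.block_until_ready(fft_jit(inp))", globals=globals())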

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.1
python: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:24:40) [GCC 13.3.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='tux', release='6.12.8-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Thu, 02 Jan 2025 22:52:26 +0000', machine='x86_64')
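
(The block above looks like the output of jax.print_environment_info(); the same helper can be used to collect these details on other setups:)

import jax
jax.print_environment_info()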

@roth-jakob added the bug label on Jan 9, 2025
@dfm
Collaborator

dfm commented Jan 9, 2025

Thanks for the report!

Pinging @ezhulenev and @penpornk who will have the most useful context.

I can't reproduce this performance degradation on Colab, but I see a performance hit of about 4x with the thunks runtime on my MacBook. Interestingly, this factor remains similar even for larger problem sizes. The HLO is simple (just a single FFT op), so I'm not totally sure what would be causing this.
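
For anyone wanting to verify this locally, here is a minimal sketch using JAX's AOT API (reusing fft and inp from the reproducer) to dump the IR before and after XLA's optimization passes:

# inspect what XLA compiles for the reproducer above
lowered = jax.jit(fft).lower(inp)
print(lowered.as_text())            # StableHLO before XLA's optimization passes
print(lowered.compile().as_text())  # optimized HLO; expect a single fft op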

@dfm self-assigned this on Jan 9, 2025
@roth-jakob
Contributor Author

roth-jakob commented Jan 9, 2025

Thanks for the quick response! Interesting that this is system-dependent. On two Linux laptops with AMD and Intel CPUs, I consistently get the slowdown with the new thunks runtime.

@matteani

I tried benchmarking this on an M1 2020 MacBook Pro and also got a slowdown.

With the old CPU runtime:

set: --xla_cpu_use_thunk_runtime=false 
jax.block_until_ready(fft(inp)) executed in: 1.32537 seconds (for 100 runs)
jax.block_until_ready(fft_jit(inp)) executed in: 1.32489 seconds (for 100 runs)

With the new CPU runtime:

jax.block_until_ready(fft(inp)) executed in: 4.27288 seconds (for 100 runs)
jax.block_until_ready(fft_jit(inp)) executed in: 4.20558 seconds (for 100 runs)

@Edenhofer
Contributor

Could this issue be escalated, please? This steep performance degradation affects almost all models built with NIFTy, and by now this bug is the sole reason we mark newer JAX versions as conflicting in NIFTy.
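
Until this is resolved, a possible stopgap (a sketch, not an endorsed fix: it must run before the first import of jax anywhere in the process) is to opt back into the old runtime via the flag from this issue:

import os
# fall back to the old CPU runtime; must happen before the first `import jax`
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
import jax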
