
Wrong results on CPU since 0.4.32 #23590

Closed
dionhaefner opened this issue Sep 12, 2024 · 9 comments

Labels
bug Something isn't working

Comments

@dionhaefner
Description

I'm seeing test failures in Veros when bumping JAX to 0.4.32.

There appear to be significant deviations from the expected results (which are computed by a Fortran reference). All tests are executed on CPU and pass for previous versions of JAX.

K_33:
    Not equal to tolerance rtol=1e-08, atol=1e-10

    Mismatched elements: 2133 / 210000 (1.02%)
    Max absolute difference: 1.01966039e+15
    Max relative difference: 291.09569767
     x: array([[[ 0.000000e+00,  0.000000e+00,  0.000000e+00, ...,
             -0.000000e+00, -1.866951e+12,  0.000000e+00],
            [ 0.000000e+00,  0.000000e+00,  0.000000e+00, ...,...
     y: array([[[ 0.000000e+00,  0.000000e+00,  0.000000e+00, ...,
              0.000000e+00,  0.000000e+00,  0.000000e+00],
            [ 0.000000e+00,  0.000000e+00,  0.000000e+00, ...,...
Ai_ez:
    Not equal to tolerance rtol=1e-08, atol=1e-10

    Mismatched elements: 16596 / 840000 (1.98%)
    Max absolute difference: 0.20468597
    Max relative difference: 0.875
     x: array([[[[[-1.648106e+00,  0.000000e+00],
              [ 2.606564e-01,  0.000000e+00]],
    ...
     y: array([[[[[-1.648106e+00,  0.000000e+00],
              [ 2.606564e-01,  0.000000e+00]],
    ...
Ai_nz:
    Not equal to tolerance rtol=1e-08, atol=1e-10

    Mismatched elements: 16574 / 840000 (1.97%)
    Max absolute difference: 0.16398475
    Max relative difference: 0.85714286
     x: array([[[[[-3.398974e-01,  0.000000e+00],
              [ 2.441893e-01,  0.000000e+00]],
    ...
     y: array([[[[[-3.398974e-01,  0.000000e+00],
              [ 2.441893e-01,  0.000000e+00]],
    ...
Ai_bx:
    Not equal to tolerance rtol=1e-08, atol=1e-10

    Mismatched elements: 8356 / 840000 (0.995%)
    Max absolute difference: 0.20468597
    Max relative difference: 0.875
     x: array([[[[[ 0.000000e+00,  0.000000e+00],
              [ 0.000000e+00,  0.000000e+00]],
    ...
     y: array([[[[[-0.000000e+00,  0.000000e+00],
              [ 0.000000e+00, -0.000000e+00]],
    ...
Ai_by:
    Not equal to tolerance rtol=1e-08, atol=1e-10

    Mismatched elements: 8354 / 840000 (0.995%)
    Max absolute difference: 0.16398475
    Max relative difference: 0.85714286
     x: array([[[[[ 0.000000e+00,  0.000000e+00],
              [ 0.000000e+00,  0.000000e+00]],
    ...
     y: array([[[[[ 0.000000e+00,  0.000000e+00],
              [ 0.000000e+00, -0.000000e+00]],
    ...

I've tried setting jax.config.update('jax_cpu_enable_async_dispatch', False) (since that's the only thing I saw in the changelog that I thought might be related), but it made no difference.
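For completeness, a minimal sketch of that check; it assumes the flag takes effect as long as it is set before the kernel is run:

import jax

# Disable asynchronous dispatch on CPU so results are computed eagerly.
# This is the setting mentioned above; it did not change the outcome here.
jax.config.update("jax_cpu_enable_async_dispatch", False)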

Anything else I could try on my end? Asking for wild guesses here, because isolating the problem properly is really nontrivial (these are complicated kernels consisting of thousands of SLOC).

System info (python version, jaxlib version, accelerator, etc.)

Python 3.12, JAX 0.4.32, CPU

dionhaefner added the bug label on Sep 12, 2024
@dfm
Collaborator

dfm commented Sep 12, 2024

Another CPU change in v0.4.32 is that some of the LAPACK wrappers in lax.linalg have been updated. It might be useful to isolate which linear algebra operations are used in the kernels to see if there's a bug there? If you can let me know which, if any, linear algebra operations you're using, I'm happy to dig on our side!
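One way to do that triage is to inspect the jitted kernel's jaxpr and look for lax.linalg primitives. This is only a sketch; kernel and args below are hypothetical placeholders for the actual Veros kernel and its inputs:

import jax

# Placeholders: substitute the real kernel function and example inputs.
jaxpr = jax.make_jaxpr(kernel)(*args)

# A non-exhaustive set of lax.linalg primitive names backed by the LAPACK wrappers.
linalg_prims = {"cholesky", "lu", "qr", "svd", "eig", "eigh",
                "triangular_solve", "tridiagonal_solve"}

# Only inspects top-level equations; primitives inside nested jaxprs
# (e.g. under scan or cond) would need a recursive walk.
used = {eqn.primitive.name for eqn in jaxpr.jaxpr.eqns}
print(used & linalg_prims)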

@hawkinsp
Collaborator

Another thing to try: does XLA_FLAGS=--xla_cpu_use_thunk_runtime=false help?

There was a major change to the implementation of the CPU backend. Notably, we'll use more concurrency on CPU. If that fixes things, please share a reproduction.
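If it's easier to test from Python than from the shell, the same flag can be set via the environment before jax is imported (a sketch; note this overwrites any XLA_FLAGS already set, and it must happen before the CPU backend is initialized):

import os

# Must be set before importing jax so XLA picks it up when creating the CPU client.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax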

@hawkinsp
Collaborator

We actually just yanked 0.4.32 because of a TPU problem, but if you can get a reproducer it'd be great to look into this.

@dionhaefner
Author

Thanks both for the suggestions. The kernel in question doesn't use linear algebra, and the suggested XLA flags didn't make a difference.

I managed to dig up a reproducer. You can run this:

$ git clone git@github.com:dionhaefner/pyhpc-benchmarks.git
$ cd pyhpc-benchmarks
$ python run.py benchmarks/isoneutral_mixing/ --device cpu -b jax -b numpy -s 4096 

With JAX 0.4.32 this prints

Warning: inconsistent results for size 4096

but not for older JAX versions.

This is the kernel that's being run, with random input arrays:

JAX version (isoneutral_jax.py), NumPy version (isoneutral_numpy.py)

@dionhaefner
Author

dionhaefner commented Sep 13, 2024

Even simpler, you can run this script in the benchmarks/isoneutral_mixing folder (no dependencies besides jax):

from isoneutral_numpy import run as run_numpy
from isoneutral_jax import run as run_jax

import numpy as np
import jax
jax.config.update("jax_enable_x64", True)


def generate_inputs(size):
    import math

    np.random.seed(17)

    shape = (
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(0.25 * size ** (1 / 3)),
    )

    # masks
    maskT, maskU, maskV, maskW = (
        (np.random.rand(*shape) < 0.8).astype("float64") for _ in range(4)
    )

    # 1d arrays
    dxt, dxu = (np.random.randn(shape[0]) for _ in range(2))
    dyt, dyu = (np.random.randn(shape[1]) for _ in range(2))
    dzt, dzw, zt = (np.random.randn(shape[2]) for _ in range(3))
    cost, cosu = (np.random.randn(shape[1]) for _ in range(2))

    # 3d arrays
    K_iso, K_11, K_22, K_33 = (np.random.randn(*shape) for _ in range(4))

    # 4d arrays
    salt, temp = (np.random.randn(*shape, 3) for _ in range(2))

    # 5d arrays
    Ai_ez, Ai_nz, Ai_bx, Ai_by = (np.zeros((*shape, 2, 2)) for _ in range(4))

    return (
        maskT,
        maskU,
        maskV,
        maskW,
        dxt,
        dxu,
        dyt,
        dyu,
        dzt,
        dzw,
        cost,
        cosu,
        salt,
        temp,
        zt,
        K_iso,
        K_11,
        K_22,
        K_33,
        Ai_ez,
        Ai_nz,
        Ai_bx,
        Ai_by,
    )

testinputs = generate_inputs(1000)

def test_run():
    inputs_np = [x.copy() for x in testinputs]
    inputs_jax = [jax.numpy.asarray(x) for x in testinputs]

    out_numpy = run_numpy(*inputs_np)
    out_jax = run_jax(*inputs_jax)

    for x_np, x_jax in zip(out_numpy, out_jax):
        np.testing.assert_allclose(x_np, x_jax)

if __name__ == "__main__":
    test_run()

@penpornk
Collaborator

Thank you for the code! I was able to reproduce the issue. Our team will look into this soon.

@penpornk
Collaborator

This is because of my f64 Tanh approximation commit: openxla/xla@ae96f6e

I've temporarily disabled the feature in openxla/xla@8fcf359. The change is now in JAX nightly 20240914 and newer:

pip install -U --pre jax==0.4.33.dev20240914 jaxlib==0.4.33.dev20240914 -f \
    https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

I've verified that the script doesn't get numerical errors anymore with the new nightly wheel. The original benchmark also passed.

python run.py benchmarks/isoneutral_mixing/ --device cpu -b jax -b numpy -s 4096 

I'll also check that both the script and the benchmark run fine before re-enabling the fast f64 Tanh approximation.
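For reference, a minimal check of f64 tanh precision on CPU, using the same tolerances as the failing tests above. This is only a sketch of how one might spot this kind of regression, not part of the benchmark itself:

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Compare JAX's CPU tanh against NumPy's double-precision tanh at the
# tolerances from the report above (rtol=1e-08, atol=1e-10).
x = np.random.default_rng(17).standard_normal(1_000_000)
np.testing.assert_allclose(np.asarray(jnp.tanh(x)), np.tanh(x),
                           rtol=1e-8, atol=1e-10)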

@hawkinsp
Collaborator

I'm going to cherry-pick this change into the jax v0.4.33 release that I'm about to make.

@hawkinsp
Collaborator

We just released JAX 0.4.33, which includes the fix for this issue.
