
Unused vmap GPU memory allocation causes RESOURCE_EXHAUSTED for versions >0.4.14 #23548

Open
pwithams opened this issue Sep 10, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@pwithams

Description

Overview

The script below works when using an NVIDIA GPU with JAX version 0.4.14, but after upgrading to 0.4.31 (and trying a few other versions in between) it triggers the following error:

E0910 20:24:00.097739 38257 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes.

where the value of X ranges from ~5GB (e.g. 4843897104) to 20GB+ depending on the shape of the dls variable (set to 3540 in the script below).

jax<=0.4.14 - no error
jax>0.4.14 - error

I'm not sure whether this is a bug, or whether some code/syntax in the example below is no longer supported in versions > 0.4.14 and is responsible for this behavior.

Allocation vs. pprof usage

The GPU has 6GB of memory. After some trial and error, it appears that setting the first dimension of the dls variable to 1590 succeeds and uses only ~500kB of memory according to pprof (following https://jax.readthedocs.io/en/latest/device_memory_profiling.html), but a first dimension of 1600 fails trying to allocate ~5GB. If pprof is in fact showing GPU memory usage, this could suggest that memory is being allocated but not used.
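A possible explanation for the gap between pprof and the failed allocation, offered as an assumption rather than a diagnosis: `jax.profiler.save_device_memory_profile` records the device buffers that are live at the moment it is called, so intermediates that only exist while the jitted `run()` executes (and are freed before it returns) would not show up. The compiled executable can instead be asked for its own estimate of scratch memory via `Compiled.memory_analysis()`. The sketch below uses a small hypothetical stand-in, `pairwise_gaussian`, not the full script, and assumes a JAX version where `memory_analysis()` is available for the backend (it may return `None` on some backends).

```python
import jax
import jax.numpy as jnp


def pairwise_gaussian(dls, ys):
    # Hypothetical stand-in: an outer difference over (n_dl, n_y, 3),
    # similar in spirit to `ys - dl` under nested vmaps in the report.
    d = dls[:, None, :] - ys[None, :, :]
    return jnp.sum(jnp.exp(-0.5 * jnp.sum(d**2, axis=-1)), axis=1)


dls = jnp.ones((1600, 3))
ys = jnp.ones((5000, 3))

compiled = jax.jit(pairwise_gaussian).lower(dls, ys).compile()
stats = compiled.memory_analysis()  # may be None on some backends

if stats is not None:
    # temp_size_in_bytes covers scratch buffers used during execution,
    # which a device memory profile taken after the call would not show.
    print("temp bytes:  ", stats.temp_size_in_bytes)
    print("output bytes:", stats.output_size_in_bytes)
```

Running the same query on `run` from the script under different JAX versions (and different dls sizes around 1590/1600) would show whether the compiled program's temporary footprint itself jumps.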

jnp.exp removal

Trial and error also showed that removing the jnp.exp calls inside the function m seems to resolve the issue. For example, the script below with the dls shape set to 10000 fails trying to allocate 30GB, but with the jnp.exp calls removed it succeeds and pprof shows only ~2MB in use.

Script

import jax
import jax.numpy as jnp
from jax import vmap


def wp(y0, ts, rng, tidx_offset):
    t0, t1 = ts[0], ts[-1]
    y = jnp.ones((11, 3)) * (t0 * t1)
    y = jnp.vstack(
        (
            y,
            jnp.ones((71, 3)) * y0,
        )
    )
    y = jnp.roll(y, tidx_offset, axis=0)
    y = y[:71]
    y = y.at[:, 2].set(jnp.abs(y[:, 2] - 0.03) + 0.03)
    return y


def ps(ys, ts, tidx_offset):
    t = jnp.maximum(ts - ts[0], 0)
    t = jnp.hstack(
        (
            t,
            jnp.zeros(71 - 11),
        )
    )
    t = jnp.roll(t, tidx_offset, axis=0)
    t = t[:71]
    ds = jnp.sqrt(jnp.sum((ys[1:, :] - ys[:-1, :]) ** 2, axis=-1))
    d = jnp.cumsum(jnp.hstack((jnp.array([0.0]), ds)), axis=-1)

    s_xyz = jnp.array([0.123, 0.345, 0.456])
    s_xyz = jnp.exp(jnp.array(-2.0)) + s_xyz
    scale = t * jnp.exp(jnp.array(-2.2)) + d * jnp.exp(jnp.array(-2.5)) + 1e-6
    s = jnp.einsum("i,x->ix", scale, s_xyz)

    return s


def m(s, d, d_mirror, rate):
    scale = 0.5 * (15.0 / 75)
    m = (
        scale
        * rate
        / ((2 * jnp.pi) ** (3 / 2) * jnp.prod(s))
        * (
            # removing these two jnp.exp calls appears to resolve the issue
            (jnp.exp(-0.5 * jnp.sum(d**2 / s**2)))
            + (jnp.exp(-0.5 * jnp.sum(d_mirror**2 / s**2)))
        )
    )
    return m


def func(y0, dl, tss, rng, rate, tidx_offset):
    ys = wp(y0, tss, rng, tidx_offset)
    d = ys - dl
    A = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
    ys_mirror = jnp.matmul(ys, A)
    d_mirror = ys_mirror - dl
    scale = 0.5 * (15.0 / 75)
    rate = rate[tidx_offset]
    s = jnp.ones((71, 3)) * y0 * 2.3

    results = vmap(m, in_axes=(0, 0, 0, None))(s, d, d_mirror, rate)
    return results


@jax.jit
def run():
    y0s = jnp.ones(shape=(1, 3))
    # dls shape of ~1600+ fails on 6GB GPU trying to allocate 5GB+
    # dls shape of <1590 succeeds on 6GB GPU and uses only ~476kB memory according to pprof
    # dls shape of 10000 fails trying to allocate 30GB, but passes and only uses ~2MB when removing jnp.exp calls above
    dls = jnp.ones(shape=(3540, 3))
    rates = jnp.ones(shape=(1, 71))
    rngs = jnp.ones(shape=(71, 75, 2), dtype="uint32")
    tss = jnp.ones(shape=(71, 11, 75))
    tidx_offsets = jnp.arange(len(tss))

    output = vmap(
        vmap(
            vmap(
                vmap(func, in_axes=(None, None, 1, 0, None, None)),
                in_axes=(None, None, 0, 0, None, 0),
            ),
            in_axes=(None, 0, None, None, None, None),
        ),
        in_axes=(0, None, None, None, 0, None),
    )(y0s, dls, tss, rngs, rates, tidx_offsets)
    result = jnp.sum(output, axis=(0, 2, 3))
    return result


result = run()
jax.profiler.save_device_memory_profile("memory.prof")
print(result)
print(result.shape)
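Independent of whether this is a regression, one way to keep peak memory bounded is to iterate over the `dls` axis sequentially with `jax.lax.map` instead of vectorizing it with `vmap`, so that only one slice of the large intermediate has to exist at a time. This is a swapped-in technique, not something proposed in the thread, and the sketch uses a hypothetical reduced kernel `per_dl` rather than the full `func`/`m` pipeline.

```python
import jax
import jax.numpy as jnp


def per_dl(dl, ys):
    # Hypothetical kernel: summed Gaussian weight of every point in `ys` relative to one `dl`.
    d = ys - dl  # (n_y, 3)
    return jnp.sum(jnp.exp(-0.5 * jnp.sum(d**2, axis=-1)))


@jax.jit
def all_vmap(dls, ys):
    # Vectorized over dls: XLA may materialize an (n_dl, n_y, 3) intermediate.
    return jax.vmap(per_dl, in_axes=(0, None))(dls, ys)


@jax.jit
def all_map(dls, ys):
    # Sequential over dls (a loop under the hood): each step needs only an (n_y, 3) intermediate.
    return jax.lax.map(lambda dl: per_dl(dl, ys), dls)


dls = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
ys = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))
print(jnp.allclose(all_vmap(dls, ys), all_map(dls, ys), rtol=1e-5))
```

The trade-off is less parallelism per step; mapping over blocks of `dls` and `vmap`-ing within each block is a common middle ground between the two.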

System info (python version, jaxlib version, accelerator, etc.)

Pip versions:

# jax
jax==0.4.31
jax-cuda12-pjrt==0.4.31
jax-cuda12-plugin==0.4.31
jaxlib==0.4.31
jaxtyping==0.2.34
# nvidia
nvidia-cublas-cu12==12.6.1.4
nvidia-cuda-cupti-cu12==12.6.68
nvidia-cuda-nvcc-cu12==12.6.68
nvidia-cuda-runtime-cu12==12.6.68
nvidia-cudnn-cu12==9.3.0.75
nvidia-cufft-cu12==11.2.6.59
nvidia-cusolver-cu12==11.6.4.69
nvidia-cusparse-cu12==12.5.3.3
nvidia-nccl-cu12==2.22.3
nvidia-nvjitlink-cu12==12.6.68

Output of jax.print_environment_info() (running inside a container based on nvidia/cuda:12.3.2-base-ubuntu22.04):

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='docker-desktop', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Sep 10 20:44:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.92       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4050 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   43C    P3              11W /  35W |     78MiB /  6141MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     46151      C   /python3.10                               N/A      |
+---------------------------------------------------------------------------------------+

Pip versions for the latest version that does not show the error (v0.4.14):

# jax 
jax==0.4.14
jaxlib==0.4.14+cuda12.cudnn89
jaxtyping==0.2.23
# nvidia
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.1.105
@pwithams pwithams added the bug Something isn't working label Sep 10, 2024
@justinjfu
Collaborator

I checked the HLO with dls=jnp.ones(shape=(10000, 3)), and it does look like your program generates some very large tensors (1 x 10000 x 71 x 75 x 71 x 3 float32 values ≈ 45GB)
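For reference, the size of the largest intermediate follows from multiplying its dimensions and the 4-byte width of f32:

```math
\begin{aligned}
N &= 1 \times 10000 \times 71 \times 75 \times 71 \times 3 = 11{,}342{,}250{,}000 \text{ elements} \\
N \times 4\,\text{B} &= 45{,}369{,}000{,}000\,\text{B} \approx 45\,\text{GB} \approx 42\,\text{GiB}
\end{aligned}
```

The reduced tensor `f32[1,10000,71,75,71]` is one third of that, about 15 GB.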

ENTRY main.152 {
  constant.27 = f32[] constant(1)
  broadcast.28 = f32[1,71]{1,0} broadcast(constant.27), dimensions={}
  iota.29 = s32[71]{0} iota(), iota_dimension=0
  ...
  constant.15 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
  dot.95 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.89, constant.15), lhs_contracting_dims={4}, rhs_contracting_dims={0}
  reshape.96 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.95)
  broadcast.97 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.96), dimensions={0,1,2,3,4,5}
  reshape.98 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.97)
  broadcast.99 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,2,3,4,5}
  subtract.100 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.99, broadcast.17)
  multiply.132 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.100, subtract.100)
  divide.133 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} divide(multiply.132, broadcast.4)
  reduce.138 = f32[1,10000,71,75,71]{4,3,2,1,0} reduce(divide.133, constant.25), dimensions={5}, to_apply=region_3.134
  multiply.139 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(reduce.138, broadcast.2)
  exponential.140 = f32[1,10000,71,75,71]{4,3,2,1,0} exponential(multiply.139)
  add.141 = f32[1,10000,71,75,71]{4,3,2,1,0} add(exponential.131, exponential.140)
  multiply.146 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(broadcast.145, add.141)
  ROOT reduce.151 = f32[10000,71]{1,0} reduce(multiply.146, constant.25), dimensions={0,2,3}, to_apply=region_4.147
}

After commenting out the two lines containing exp, these large tensors are not materialized:

  ...
  constant.12 = f32[] constant(1)
  reduce.27 = f32[1,71]{1,0} reduce(broadcast.6, constant.12), dimensions={2}, to_apply=region_0.23
  constant.1 = f32[] constant(15.7496099)
  broadcast.2 = f32[1,71]{1,0} broadcast(constant.1), dimensions={}
  multiply.28 = f32[1,71]{1,0} multiply(reduce.27, broadcast.2)
  reshape.29 = f32[1,1,71]{2,1,0} reshape(multiply.28)
  broadcast.34 = f32[1,1,71]{2,1,0} broadcast(reshape.29), dimensions={0,1,2}
  reshape.35 = f32[1,71]{1,0} reshape(broadcast.34)
  broadcast.36 = f32[1,71,71]{2,1,0} broadcast(reshape.35), dimensions={0,2}
  divide.37 = f32[1,71,71]{2,1,0} divide(broadcast.33, broadcast.36)
  broadcast.38 = f32[1,10000,71,75,71]{4,3,2,1,0} broadcast(divide.37), dimensions={0,2,4}
  constant.11 = f32[] constant(0)
  ROOT reduce.43 = f32[10000,71]{1,0} reduce(broadcast.38, constant.11), dimensions={0,2,3}, to_apply=region_1.39
}
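One reading of this second listing (an interpretation, not confirmed in the thread): the visible chain leading to ROOT is a divide of values derived from `s` and `rate`, with no `subtract` against `dl`, which suggests the `d`/`d_mirror` terms dropped out of the computation entirely once the exp lines were commented out. If so, the two variants are not computing the same thing, and XLA's dead-code elimination simply has no large tensor left to keep. The toy check below (`unused_outer_diff` is hypothetical) shows that a large intermediate whose result never reaches the output is absent from the optimized HLO.

```python
import jax
import jax.numpy as jnp


def unused_outer_diff(x, y):
    d = x[:, None, :] - y[None, :, :]  # (n_x, n_y, 3) outer difference
    _unused = jnp.sum(d**2)            # computed but never returned
    return jnp.sum(x)


x = jnp.ones((2000, 3))
y = jnp.ones((3000, 3))
optimized_hlo = jax.jit(unused_outer_diff).lower(x, y).compile().as_text()

# Dead-code elimination should drop the subtract and its (2000, 3000, 3) result.
print("subtract" in optimized_hlo, "2000,3000,3" in optimized_hlo)
```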

I'm not sure why this code runs on JAX <=0.4.14; it's possible some optimizations are applied differently there. You can inspect the HLO yourself using:
run.lower().compiler_ir(dialect='hlo').as_hlo_text() (for >=0.4.30)
jax.xla_computation(run)().as_hlo_text() (for <0.4.30)
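A sketch of two views of the same program for a small hypothetical function, `scaled_exp`. To my understanding, `Lowered.as_text()` prints the StableHLO module JAX hands to XLA before optimization (like the `compiler_ir(dialect='hlo')` output above, it is a pre-optimization view), while `Compiled.as_text()` prints XLA's optimized HLO after fusion and simplification. Shapes in the pre-optimization text do not necessarily become real buffers, since fusion can avoid materializing them, so the optimized text is likely the more useful one to diff between JAX versions.

```python
import jax
import jax.numpy as jnp


def scaled_exp(x):
    return jnp.sum(jnp.exp(-0.5 * x**2))


x = jnp.ones((1024,))
lowered = jax.jit(scaled_exp).lower(x)

pre_opt = lowered.as_text()    # StableHLO, before XLA optimization passes
compiled = lowered.compile()
post_opt = compiled.as_text()  # optimized HLO, after fusion and other passes

print("stablehlo.exponential" in pre_opt)
print(post_opt[:400])
```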

@pwithams
Author

Thanks for the response. I'm starting to think that some change in OpenXLA or in the lowering step, rather than JAX itself, is responsible. A few questions:

  • what part of the HLO text shows when a tensor is "materialized"? Are there docs/links on how to read these outputs?
  • what's the difference between func.lower().as_text() and run.lower().compiler_ir(dialect='hlo').as_hlo_text()?
  • how do you determine expected memory based on tensor shape?
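On the first and third questions, a rough heuristic (an assumption, not an official procedure): each HLO instruction's result type, e.g. `f32[1,10000,71,75,71,3]`, gives its size as the product of the dimensions times the element width (4 bytes for f32). Scanning an HLO dump for the largest result types points at candidate large buffers, although in unoptimized HLO many of them may later be fused away. A minimal sketch over a few lines copied from the listing in the thread; `largest_buffers` and the `DTYPE_BYTES` table are hypothetical helpers.

```python
import math
import re

# Bytes per element for a few common HLO element types (not exhaustive).
DTYPE_BYTES = {"pred": 1, "s8": 1, "u8": 1, "bf16": 2, "f16": 2,
               "s32": 4, "u32": 4, "f32": 4, "s64": 8, "f64": 8}

# Matches "name = f32[1,2,3]" (optionally prefixed by ROOT) at the start of a line.
INSTR = re.compile(r"^\s*(?:ROOT\s+)?([\w.\-]+)\s*=\s*(\w+)\[([\d,]*)\]", re.M)


def largest_buffers(hlo_text, top=3):
    """Return (bytes, name, type) for the largest array-typed instructions."""
    sizes = []
    for name, dtype, dims in INSTR.findall(hlo_text):
        if dtype not in DTYPE_BYTES:
            continue
        shape = [int(d) for d in dims.split(",") if d]
        nbytes = math.prod(shape) * DTYPE_BYTES[dtype]  # scalars: prod([]) == 1
        sizes.append((nbytes, name, f"{dtype}[{dims}]"))
    return sorted(sizes, reverse=True)[:top]


hlo_excerpt = """
  broadcast.99 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,2,3,4,5}
  reduce.138 = f32[1,10000,71,75,71]{4,3,2,1,0} reduce(divide.133, constant.25), dimensions={5}
  ROOT reduce.151 = f32[10000,71]{1,0} reduce(multiply.146, constant.25), dimensions={0,2,3}
"""

for nbytes, name, ty in largest_buffers(hlo_excerpt):
    print(f"{name:14s} {ty:28s} {nbytes / 1e9:8.2f} GB")
```

Applied to the output of `Compiled.as_text()` rather than the pre-optimization HLO, the same scan gives a closer picture of what XLA actually keeps, though fused computations still only approximate real allocations.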

Do you think this is a bug, or just an old edge case that no longer works? With dls=jnp.ones(shape=(1590, 3)) the program ran successfully and pprof reported ~500kB of memory usage, but increasing to dls=jnp.ones(shape=(1600, 3)) fails trying to allocate ~5GB, which seems like strange behavior.
