I have a NumPyro model that runs on a 4090 GPU, and the result is fine.
But when I try to visualize it using from_numpyro, I get this out-of-memory error:
2024-01-25 19:54:08.798078: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.73GiB (rounded to 14745600000)requested by op
2024-01-25 19:54:08.798443: W external/tsl/tsl/framework/bfc_allocator.cc:497] *******************************************************************************_____________________
2024-01-25 19:54:08.798689: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 14745600000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 802.3KiB
constant allocation: 33B
maybe_live_out allocation: 13.73GiB
preallocated temp allocation: 7.92MiB
total allocation: 13.74GiB
Peak buffers:
Buffer 1:
Size: 13.73GiB
Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_1024906/1472702232.py" source_line=4
XLA Label: fusion
Shape: f32[500,18,640,640]
==========================
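For reference, the failed allocation matches the peak-buffer shape exactly; a quick sanity check using only the numbers printed in the log above:

```python
# 500 * 18 * 640 * 640 float32 values, 4 bytes each, is exactly the
# 14745600000-byte allocation that XLA reports as failing.
elements = 500 * 18 * 640 * 640
print(elements * 4)           # 14745600000 bytes
print(elements * 4 / 2**30)   # ~13.73 GiB, the "maybe_live_out allocation"
```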
I searched the JAX documentation and found that JAX preallocates a large fraction of GPU memory (75% by default) before running, so it seems ArviZ is trying to allocate memory on top of this huge reserved space.
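If the preallocation itself is the concern, here is a minimal sketch of how to disable or shrink it, using the environment variables documented in the JAX GPU memory allocation guide (they must be set before jax is imported):

```python
import os

# Disable JAX's up-front GPU memory preallocation entirely; memory is then
# allocated on demand (potentially slower and more fragmented, but it leaves
# the rest of the GPU free for other allocations).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Or keep preallocation but cap it at a smaller fraction than the default.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # import only after the environment variables are set
print(jax.devices())
```

Note that this only changes how memory is reserved; it does not shrink the 13.73 GiB buffer that the failing op itself needs.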
Reviewing issues and seeing this now, I realize it is probably the attempt to allocate the pointwise log-likelihood values that generates this error. Does the issue happen both when using az.from_numpyro(..., log_likelihood=False) and when using az.from_numpyro(..., log_likelihood=True)?
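To make the comparison concrete, here is a self-contained sketch of the two calls; the tiny toy model is an assumption, since the original model code is not included in the issue:

```python
import arviz as az
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

y_obs = dist.Normal(0.5, 1.0).sample(random.PRNGKey(0), (100,))
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(1), y=y_obs)

# Skips computing the pointwise log likelihood entirely:
idata_no_ll = az.from_numpyro(mcmc, log_likelihood=False)

# Default behaviour: evaluates the log likelihood for every posterior draw
# and every observation, which for the model in this issue is the
# ~13.7 GiB f32[500, 18, 640, 640] buffer shown in the traceback.
idata_ll = az.from_numpyro(mcmc, log_likelihood=True)
```

If only the log_likelihood=True call fails, that would confirm the pointwise log likelihood is what exhausts GPU memory.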