
from_numpyro ran out of memory when the model is run on gpu #2307

Open
yaowang1111 opened this issue Jan 25, 2024 · 3 comments

@yaowang1111

I have a NumPyro model that runs on a 4090 GPU, and the results are fine. But when I try to convert the results for visualization with from_numpyro, I get the out-of-memory error below.

```
2024-01-25 19:54:08.798078: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.73GiB (rounded to 14745600000)requested by op
2024-01-25 19:54:08.798443: W external/tsl/tsl/framework/bfc_allocator.cc:497] *******************************************************************************_____________________
2024-01-25 19:54:08.798689: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 14745600000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 802.3KiB
constant allocation: 33B
maybe_live_out allocation: 13.73GiB
preallocated temp allocation: 7.92MiB
total allocation: 13.74GiB
Peak buffers:
Buffer 1:
	Size: 13.73GiB
	Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_1024906/1472702232.py" source_line=4
	XLA Label: fusion
	Shape: f32[500,18,640,640]
	==========================

Buffer 2:
	Size: 7.03MiB
	Operator: op_name="jit(scan)/jit(main)/while/body/broadcast_in_dim[shape=(18, 640, 640) broadcast_dimensions=()]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=67
	XLA Label: fusion
	Shape: pred[18,640,640]
	==========================

Buffer 3:
	Size: 540.0KiB
	Operator: op_name="jit(scan)/jit(main)/while[cond_nconsts=0 body_nconsts=26]" source_file="/tmp/ipykernel_1024906/1472702232.py" source_line=4
	XLA Label: fusion
	Shape: f32[12,640,18]
	==========================

Buffer 4:
	Size: 540.0KiB
	Operator: op_name="jit(scan)/jit(main)/while[cond_nconsts=0 body_nconsts=26]" source_file="/tmp/ipykernel_1024906/1472702232.py" source_line=4
	Entry Parameter Subshape: f32[640,18,12]
	==========================

Buffer 5:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,18,640]
	==========================

Buffer 6:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,640,18]
	==========================

Buffer 7:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,640,18]
	==========================

Buffer 8:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,640,18]
	==========================

Buffer 9:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,640,18]
	==========================

Buffer 10:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/slice[start_indices=(0, 0, 10) limit_indices=(640, 18, 11) strides=None]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=13
	XLA Label: fusion
	Shape: f32[1,640,18]
	==========================

Buffer 11:
	Size: 45.0KiB
	XLA Label: fusion
	Shape: f32[18,640]
	==========================

Buffer 12:
	Size: 45.0KiB
	Operator: op_name="jit(scan)/jit(main)/while/body/broadcast_in_dim[shape=(18, 640) broadcast_dimensions=()]" source_file="/tmp/ipykernel_1024906/1290945412.py" source_line=67
	XLA Label: fusion
	Shape: f32[18,640]
	==========================

Buffer 13:
	Size: 45.0KiB
	Entry Parameter Subshape: s32[640,18]
	==========================

Buffer 14:
	Size: 23.4KiB
	Entry Parameter Subshape: f32[500,12]
	==========================

Buffer 15:
	Size: 23.4KiB
	Entry Parameter Subshape: f32[500,12]
	==========================
```

I searched the JAX documentation and found that JAX preallocates 75% of GPU memory by default before running anything; it seems the ArviZ conversion then fails because it is trying to allocate memory on top of this huge reserved block.
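(For reference, JAX's preallocation behaviour can be disabled or reduced with environment variables, which must be set before `jax` is imported. A minimal sketch using the documented XLA client flags; the model and inference code themselves are omitted:)

```python
import os

# Turn off JAX's default preallocation (75% of GPU memory) and
# allocate GPU memory on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, keep preallocation but cap it at a smaller fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # must be imported after the flags are set
```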

@ahartikainen
Contributor

Is there a way to move things to the CPU on the NumPyro side?

@yaowang1111
Author

@ahartikainen I tried, but it seems the MCMC object is not a JAX object, so jax.device_put() doesn't know how to deal with it.
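(One possible workaround, a sketch not tested against this model: the MCMC object itself is not a pytree, but the dictionary returned by mcmc.get_samples() is, so the samples can be moved to host memory and converted from there, bypassing from_numpyro entirely:)

```python
import arviz as az
import jax

cpu = jax.devices("cpu")[0]

# `mcmc` is the fitted numpyro.infer.MCMC object from the original model.
# get_samples(group_by_chain=True) returns a dict of (chain, draw, ...)
# arrays, i.e. a pytree, so device_put can be mapped over it.
samples = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, cpu),
    mcmc.get_samples(group_by_chain=True),
)

# from_dict skips the log-likelihood computation that from_numpyro performs.
idata = az.from_dict(posterior=samples)
```

This loses the extra groups (log_likelihood, observed_data, ...) that from_numpyro would build, but it avoids touching the GPU during conversion.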

@OriolAbril
Member

Reviewing issues and seeing this now, I realize it might be the allocation of the pointwise log-likelihood values that generates this error. Does the issue happen both when using az.from_numpyro(..., log_likelihood=False) and az.from_numpyro(..., log_likelihood=True)?
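(For what it's worth, the numbers in the log are consistent with that hypothesis: the failing buffer f32[500,18,640,640] is 500 × 18 × 640 × 640 × 4 bytes = 14,745,600,000 bytes, exactly the requested allocation, i.e. one float per posterior draw per observation. A quick check, assuming `mcmc` is the fitted MCMC object:)

```python
import arviz as az

# Converts without computing the pointwise log-likelihood at all;
# if this succeeds while log_likelihood=True fails, the OOM comes
# from materialising the (draws, *obs_shape) log-likelihood array.
idata = az.from_numpyro(mcmc, log_likelihood=False)
```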
