Looking for ways to save memory through batchify #3865
Hey, thanks for the question! NeRF is awesome.
Ah, the trouble is that ...
Are you using ...?
Overall, if you could share a full runnable repro that shows the memory explosions you're seeing, that would make it much easier to help! Think that's possible?
I think what you're looking for (in JAX terms) is a map that is partially-sequential, partially-vectorized (so intermediate between lax.map and jax.vmap).
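Roughly the idea, as a minimal sketch (chunked_map, chunk_size, and the per-example fn here are placeholder names, not anything from the actual code in this thread):

import jax
import jax.numpy as jnp
from jax import lax

def chunked_map(fn, xs, chunk_size=1024):
    # fn maps a single example; lax.map loops sequentially over chunks,
    # while jax.vmap vectorizes fn within each chunk.
    n = xs.shape[0]
    assert n % chunk_size == 0, "pad xs so chunk_size divides its length"
    xs = jnp.reshape(xs, (n // chunk_size, chunk_size) + xs.shape[1:])
    ys = lax.map(jax.vmap(fn), xs)
    return jnp.reshape(ys, (n,) + ys.shape[2:])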
I was thinking along similar lines. One thing to try is putting jax.remat inside batchify:

def batchify(fn, chunk=1024 * 32):
    return jax.remat(lambda inputs: ...)

But under a ...
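Spelled out a bit more (this is just my guess at the elided body above, not necessarily what was meant):

import jax
import jax.numpy as jnp

def batchify(fn, chunk=1024 * 32):
    # jax.remat recomputes fn's intermediates on the backward pass
    # instead of storing them for every chunk.
    remat_fn = jax.remat(fn)
    return lambda inputs: jnp.concatenate(
        [remat_fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
    )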
Hmm, it's not clear to me how that ...
The repo has been created! You can clone it and run the code in /Code/NeRF-jax.py. More information can be found in the post. Unfortunately, adding jax.remat inside batchify did not work.
Yes, that sounds promising. I never thought you could use lax.map together with jax.vmap. I will definitely try it later.
Didn't work in that you got an error, or in that there was no error but you're still seeing too much memory use?
The latter. Memory use was still high.
Thanks for your advice, but I am not very familiar with lax.map yet. Below I have the input batchified by the number of rows in the image. For example, the image size in this case is (95, 126), so I batchified the data into chunks of row_batch=5 rows each. The loss function has not been changed yet.
Is it possible to move the definition of the objective function outside of that loop?
# Sample points along each ray and flatten them for the network.
pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
pts_flat = jnp.reshape(pts, [-1, 3])
pts_flat = embed_fn(pts_flat, L_embed)
# Original Python-loop batchify, replaced by lax.map below:
# def batchify(fn, chunk=1024 * 32):
#     return lambda inputs: jnp.concatenate(
#         [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
#     )
# raw = batchify(net_fn)(pts_flat)
# lax.map runs net_fn sequentially over chunks of batch_size points.
raw = lax.map(net_fn, jnp.reshape(pts_flat, [-1, batch_size, pts_flat.shape[-1]]))
raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])
Hi myagues! Thanks for sharing your code. Unfortunately, it does not work either. I ran it on Google Colab with N_samples increased to 640 instead of 64 and got an OOM error. I am doing a project on NeRF and I really want this code to run well, so I think it is worth figuring out why lax.map and jax.remat are not working in this case. As I understand it, the problem is that you need the values from all the batches in order to calculate the loss, which is the distance between the whole rendered image and the reference image. Nonetheless, this worked fine in TensorFlow 1.14, as shown in the team's code: https://github.com/bmild/nerf. After experimenting, I found that both my code in flax.nn and myagues' code in jax.experimental.stax have this problem. Maybe there are specific ways to deal with this issue. Any further ideas, everyone? @mattjj, could you look further into the code, please?
An alternative I tried was to hand-batchify the image into rows: I used a for loop to go through each batch, calculated the loss on each sub-image, summed up the gradients, and then applied the summed gradients in a single optimizer update. Unfortunately, it was about 8-10 times slower than it should be.
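Roughly what that loop looked like (a simplified sketch with placeholder names; loss_fn, params, row_batches, opt_state, and opt_update are stand-ins, not my actual code):

import jax
import jax.numpy as jnp

def accumulate_and_update(loss_fn, params, row_batches, opt_state, opt_update, step):
    # Sum gradients over row batches in a plain Python loop,
    # then apply one optimizer update with the summed gradients.
    grad_fn = jax.jit(jax.grad(loss_fn))
    grads_sum = None
    for batch in row_batches:
        g = grad_fn(params, batch)
        grads_sum = g if grads_sum is None else jax.tree_util.tree_map(jnp.add, grads_sum, g)
    return opt_update(step, grads_sum, opt_state)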
Could you please elaborate on your point with the code?
I'm currently implementing NeRF, a neural rendering method published this year with great performance. The original code is in TensorFlow, and I mainly worked from this vanilla lego example: https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb
A runnable version of my code can be found in this GitHub repository: https://github.com/BoyuanJackChen/NeRF-Implementation
You can run "/Code/NeRF-jax.py" directly. I put a jax profiler exporter under raw = batchify(network_fn)(pts_flat), and you can see that memory usage increases proportionally as you increase the N_samples variable.
While transferring it to JAX, I ran into a problem with batchifying the training data. I hope that, instead of hand-batchifying, which is troublesome, buggy, and slow, there is a way to do the job as elegantly as the TensorFlow counterpart. I have put the code down below. Hope to receive some enlightenment!
The idea of NeRF is that for each pixel in an image, you shoot one ray through that pixel, sample N points along the ray, use the model to calculate the (r, g, b, alpha) value at each point, and then sum them back to get the rgb value for that pixel. The loss is then calculated from the difference between the model-generated image and the original image.
The problem is that calculating so many sample points for over 8k pixels at once consumes too much memory. Therefore, the original code offers a way to batchify in TensorFlow, roughly like this:
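import tensorflow as tf

# The batchify helper, as best I recall it from the tiny_nerf notebook; it is the same
# as the jnp version quoted elsewhere in this thread, with tf.concat in place of
# jnp.concatenate.
def batchify(fn, chunk=1024 * 32):
    return lambda inputs: tf.concat(
        [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
    )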
With this, TensorFlow supposedly computes the gradients from one chunk of pts at a time.
I tried to do the same thing by simply changing tf.concat to jnp.concatenate. Nonetheless, the code simply does not batchify, and my memory soon explodes! I can certainly hand-batchify by splitting pts_flat with reshape and calculating the gradients for each sub-array, but that gets ugly and inefficient. I wonder if there is a good, quick way to solve this.
By the way, I checked some previous threads on jax.remat. I attached @jax.remat above render_rays, and it didn't work. I also tried to attach it above batchify, without success: it fails with "<class 'function'> is not a valid JAX type". Maybe I just didn't use it the right way. Hope you guys can help me!
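To be concrete, the failing attempt looked roughly like this (simplified; this is a reconstruction, not the exact code from my repo):

import jax
import jax.numpy as jnp

# Decorating batchify itself, which is what I tried. batchify's argument fn is a
# *function*, so calling the remat-wrapped version, e.g. batchify(net_fn)(pts_flat),
# passes a function where JAX expects arrays, and I believe that is where
# "<class 'function'> is not a valid JAX type" comes from.
@jax.remat
def batchify(fn, chunk=1024 * 32):
    return lambda inputs: jnp.concatenate(
        [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
    )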
I can provide further information if you need it!