lax.map does not work in this NeRF code #4126

Open

BoyuanJackChen opened this issue Aug 22, 2020 · 8 comments

Labels: NVIDIA GPU (Issues specific to NVIDIA GPUs), P1 (soon) (Assignee is working on this now, among other tasks; assignee required)

@BoyuanJackChen commented Aug 22, 2020

This problem has been bothering me for months, and I really want to know how to fix it. When implementing NeRF with JAX, I want to divide the data into batches so that the code can run on larger images with more sample points without losing speed. Following suggestions from another thread, I tried lax.map and jax.remat, but neither worked. I also tried hand-batching the data and computing each batch's loss in a for loop, but that was unacceptably slow (about 10x slower).

The original code published by the NeRF authors uses TensorFlow 1.14. The part where they batch the input looks like this:

def batchify(fn, chunk=1024 * 32):
    # Apply fn to `inputs` in chunks of `chunk` rows and concatenate the results.
    return lambda inputs: jnp.concatenate(
        [fn(inputs[i : i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0,
    )

raw = batchify(net_fn)(pts_flat)

In JAX, I tried the following:

raw = lax.map(net_fn, jnp.reshape(pts_flat, [-1, batch_size, pts_flat.shape[-1]]))

and

def batchify(fn, chunk=1024*32):
    return jax.remat(lambda inputs : jnp.concatenate([fn(inputs[i:i + chunk])
                                         for i in range(0, inputs.shape[0], chunk)], 0))
...
raw = batchify(net_fn)(pts_flat)

Both compiled, but neither saved memory.

My guess is that because the whole rendered image is needed to compute the loss, this batching procedure does not help. Nonetheless, the TensorFlow 1.14 code has exactly the same structure, and tf.GradientTape seemed to handle the batching just fine.

I am grateful for the code example from myagues in my previous post #3865. Unfortunately, their code did not solve the problem either. Here are links to the two failed JAX attempts, mine: https://github.com/BoyuanJackChen/NeRF-Implementation.git
and myagues's: https://github.com/myagues/potpourri/blob/master/jax/tiny_nerf_jax.ipynb
I created my model with flax.nn, and myagues used jax.experimental.stax. Both are able to learn, yet neither saves memory. If you want to take a look, I think my code is a little simpler to read.

I sincerely hope this problem can be fixed. I would be super grateful for any help!

@shoyer (Collaborator) commented Aug 22, 2020

Can you try lax.map(jax.remat(net_fn), ...)?

The remat call has to decorate the function for which you want to save only a single gradient checkpoint. Within remat, values from the forward pass get recomputed instead of saved.

Generally speaking, there isn't any point in applying remat only once over the entire computation, because you don't end up saving memory once you add back the re-evaluation of the forward pass.
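
For illustration, here is a minimal sketch of this pattern; the net_fn, chunk size, and shapes are placeholders rather than the actual code from this issue:

import jax
import jax.numpy as jnp
from jax import lax

# Placeholder network: maps a chunk of embedded points (chunk, 51) to raw outputs (chunk, 4).
def net_fn(x):
    return jnp.tanh(x @ jnp.ones((x.shape[-1], 4)))

def chunked_apply(pts_flat, chunk=1024 * 32):
    # Split the flattened points into (num_chunks, chunk, features); this assumes
    # the leading dimension is divisible by `chunk`.
    batched = jnp.reshape(pts_flat, (-1, chunk, pts_flat.shape[-1]))
    # lax.map runs the chunks sequentially; jax.remat makes the backward pass
    # recompute each chunk's forward values instead of storing all of its activations.
    raw = lax.map(jax.remat(net_fn), batched)
    return jnp.reshape(raw, (-1, raw.shape[-1]))

The memory saving only shows up under jax.grad / jax.value_and_grad, since that is when the forward activations would otherwise have to be kept alive for the backward pass.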

@BoyuanJackChen (Author) commented Aug 22, 2020

@shoyer Changing only that one line didn't work. Could you elaborate on where else to add remat?

@BoyuanJackChen (Author) commented Aug 22, 2020

For convenient viewing, the whole loss function looks like this:

def loss_fn(network_fn):
    def batchify(fn):  # helper from the earlier remat attempt; not used below
        return jax.remat(lambda inputs: jnp.concatenate([fn(inputs[i:i + batchify_size])
                                                         for i in range(0, inputs.shape[0], batchify_size)], 0))
    z_vals = jnp.linspace(near, far, N_samples)
    if rand:
        key, subkey = jax.random.split(thekey)
        z_vals += jax.random.uniform(subkey, list(rays_o.shape[:-1]) + [N_samples], dtype=jnp.float32) \
                  * (far - near) / N_samples
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    pts_flat = jnp.reshape(pts, [-1, 3])
    pts_flat = embed_fn(pts_flat)  # pts_flat is an array of shape (H*W*N, 51)
    raw = lax.map(jax.remat(network_fn), jnp.reshape(pts_flat, [-1, batchify_size, pts_flat.shape[-1]]))
    raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])
    sigma_a = nn.relu(raw[..., 3])  # (H, W, N)
    rgb = nn.sigmoid(raw[..., :3])  # (H, W, N, 3)
    dists = jnp.concatenate((z_vals[..., 1:] - z_vals[..., :-1],
                             jnp.broadcast_to([1e10], z_vals[..., :1].shape)), -1)  # (H, W, N)
    alpha = 1. - jnp.exp(-sigma_a * dists)
    weights = alpha * jnp.cumprod(1. - alpha + 1e-10, axis=-1, dtype=jnp.float32)  # (H, W, N)
    rgb = jnp.sum(weights[..., None] * rgb, -2)
    loss = jnp.mean(jnp.square(rgb - target))
    return loss, rgb

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grad = grad_fn(optimizer.target)
optimizer = optimizer.apply_gradient(grad)
print(f"Step {i} done; loss is {loss}")

@cgarciae (Collaborator) commented Aug 22, 2020

@BoyuanJackChen I don't know if I am being too naive, but what if you just batch the data outside of JAX? That is, create a generator that produces NumPy arrays of a certain batch size. You can easily do this with a tf.data.Dataset and use it with JAX by converting it to an iterator of NumPy arrays via its .as_numpy_iterator() method; you can even use PyTorch's Dataset + DataLoader without the to_tensor transformation.
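
As a rough sketch of that idea (the shapes, batch size, and train_step below are made up for illustration, not taken from your code):

import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp

# Placeholder arrays standing in for the rays/targets from the NeRF preprocessing.
rays = np.random.rand(100_000, 3).astype(np.float32)
targets = np.random.rand(100_000, 3).astype(np.float32)

ds = tf.data.Dataset.from_tensor_slices((rays, targets)).batch(4096)

@jax.jit
def train_step(rays_batch, targets_batch):
    # Stand-in for the real NeRF update; just computes a dummy loss.
    return jnp.mean((rays_batch - targets_batch) ** 2)

# as_numpy_iterator() yields plain NumPy arrays, which a jitted JAX function accepts directly.
for rays_batch, targets_batch in ds.as_numpy_iterator():
    loss = train_step(rays_batch, targets_batch)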

BTW: Is there a reason why the authors implemented their own training loop in the colab you sent? It seems tf.keras.Model.fit could do the job.

@BoyuanJackChen (Author) commented Aug 22, 2020

@cgarciae Thanks for your advice! Iterators are a blind spot for me, so I am not sure how fast they would be. Do you think this will be as fast as TensorFlow's automatic batching, or does that batching use something like an iterator under the hood? One of my previous attempts was to divide the image by rows (say, 24 rows per batch) and then use a for loop to calculate the loss and gradient for each batch. It was super slow, but it worked. Would the iterator approach have a similar structure? Could you enlighten me with more details?

@cgarciae (Collaborator) commented Aug 22, 2020

@BoyuanJackChen tf.data uses a lot of tricks to keep the GPU busy; in particular, the idea is to do the preprocessing and batching on the CPU in parallel with the forward + backprop steps on the GPU, so the GPU doesn't have to wait for each new batch to be ready.

Check out this video on tf.data.
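
Building on the sketch above, the overlap mostly comes from prefetch: while the accelerator runs the current step, tf.data prepares the next batch on the CPU. Roughly (same placeholder arrays as before):

# Batch on the CPU and prefetch so the next batch is ready while the
# GPU is still busy with the current training step.
ds = (tf.data.Dataset.from_tensor_slices((rays, targets))
      .shuffle(10_000)
      .batch(4096, drop_remainder=True)
      .prefetch(tf.data.experimental.AUTOTUNE))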

@BoyuanJackChen (Author) commented

@cgarciae Another tricky part is that I don't put multiple images into one batch; instead, I need to divide each image into multiple batches. I guess I should make a tf.data.Dataset from the rays of each selected image and then process each batch in a for loop?

@BoyuanJackChen (Author) commented Aug 27, 2020

@cgarciae Thanks for your advice! It actually improved the speed on CPU to a great extent! Below is a demo of how I used it:

target_data = tf.data.Dataset.from_tensor_slices(target_batched)
rays_o_data = tf.data.Dataset.from_tensor_slices(rays_o_batched)
rays_d_data = tf.data.Dataset.from_tensor_slices(rays_d_batched)
# Iterate the datasets as NumPy arrays (rather than iterating the raw batched arrays).
target_iter = target_data.as_numpy_iterator()
rays_o_iter = rays_o_data.as_numpy_iterator()
rays_d_iter = rays_d_data.as_numpy_iterator()
...
for target_batch, rays_o_batch, rays_d_batch in zip(target_iter, rays_o_iter, rays_d_iter):
    ...

Nonetheless, it still doesn't work on GPU; the code is even slower than the for loop. My guess is that the iterator is placed on the CPU by default. Is there any way to make it run faster on GPU?

@sudhakarsingh27 added the NVIDIA GPU (Issues specific to NVIDIA GPUs) and P1 (soon) labels on Aug 10, 2022
@sudhakarsingh27 self-assigned this on Sep 7, 2022