Memory and running time issues for CNN #29

Closed
HornHehhf opened this issue Apr 13, 2020 · 11 comments
Labels
enhancement New feature or request

Comments

@HornHehhf

Hi,

I am currently using Neural Tangents to compute the kernel for CIFAR-10 images. I need to compute the kernel matrix for 10000 x 10000 images, and each image has 3x32x32 pixels. With a 2-layer feedforward network on inputs reshaped to 3072 dimensions, it took about 3G of memory and several minutes to compute the kernel.

However, with a simple one-layer CNN, I get the error "failed to allocate request 381T memory". The only workaround is to reduce the minibatch size, but that makes the computation much slower. And this is just a one-layer CNN; I expect a multilayer CNN to take even longer. Even a single batch (100 images) takes much more time than the 2-layer feedforward network.

Another strange thing: I expected to be able to compute the kernel matrix with a batch size of 200 (out of 10000), since the server has 394G of memory. But it still runs out of memory (checked manually) after several minutes and is killed without an error message.

So I am wondering how to use your tools to compute the kernel matrix for CNNs. On my end it costs either too much memory or too much time. Do you have any suggestions for dealing with this? I am not sure how you compute the CNN kernel under the hood, but I would not expect it to cost so much memory or run so slowly, because [Arora et al. 2019](https://arxiv.org/pdf/1904.11955.pdf) compute the kernel for a 21-layer CNN.

It is a really good tool, and I hope you can help with the CNN memory and running-time issue.

Thanks,
Hangfeng

@romanngg
Contributor

Hey Hangfeng, I'm afraid that computing CNN kernels with pooling layers is in general very costly, so I don't have a good solution for you right now. Some specific thoughts:

  1. 381T is about the size of a (10000 * 32 * 32)^2 float32 covariance matrix, which you need to compute as an intermediate step if your network has pooling layers. You need to wrap your kernel_fn with the nt.batch decorator, and the batch_size would need to be somewhere in the range of 10-100 depending on your CPU/GPU RAM.

  2. I'm not certain why a batch size of 200 runs out of memory (are you running it in CPU RAM?), but I suspect it may be due to JAX having to allocate 3-4X the tensor size (instead of 2X) because of the buffer donation limitations tracked in jax-ml/jax#1273 ("Buffer donation to a jit function on GPU") and jax-ml/jax#1733 ("Add support for buffer donation (input/output aliasing)"). In that case there isn't much we can do (please upvote those issues!), and you would need to scale down the batch size even further. If this explanation is correct, I think a batch size of 130-150 should work.

  3. If you can, please post your code snippet and I can check whether any performance improvements are possible. In general, though, it seems about right to me that the FC kernel takes minutes while the CNN+pooling kernel takes thousands of hours, since in the intermediate layers they effectively work with kernels of size 10000^2 and (10000 x 32 x 32)^2 respectively. One middle ground is to drop all pooling layers and put a stax.Flatten() layer at the top; in this case the effective kernel size is 10000^2 x 32 x 32. (If you are seeing these errors with stax.Flatten() and no pooling layers, please let me know, since that might be a fixable bug.)

  4. FYI, under the hood the CNN kernel computation (both with and without pooling layers) uses convolutions, which AFAIK are much more efficient on GPUs than on CPUs, especially for some common settings (e.g. 3x3 kernels, SAME padding). This might be a secondary reason why the computation is so slow, and you might get noticeably better results with a GPU and a much smaller batch size (it will still be very costly, though).
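A back-of-the-envelope check of the sizes in points 1-3 above, as a minimal pure-Python sketch. It assumes float32 (4 bytes per entry) and 32x32 images, reads "381T" and the "394G" server as binary units (TiB/GiB), and models the suspected JAX overhead as a plain multiplier of 2-4 live copies of one batch block. The names here are illustrative helpers, not Neural Tangents functions.

```python
BYTES = 4  # float32
N, H, W = 10_000, 32, 32
GIB, TIB = 2 ** 30, 2 ** 40

# Number of covariance entries tracked in the intermediate layers.
entries = {
    "fc": N * N,                   # fully-connected: one entry per image pair
    "cnn_flatten": N * N * H * W,  # no pooling: same-pixel entries per image pair
    "cnn_pool": (N * H * W) ** 2,  # pooling: every (image, pixel) vs every (image, pixel)
}
for name, n in entries.items():
    print(f"{name:12s} {n * BYTES / GIB:>14,.1f} GiB")

# Point 1: the full pooling covariance is ~381 TiB.
print(f"cnn_pool total: {entries['cnn_pool'] * BYTES / TIB:.0f} TiB")


def pool_block_gib(batch_size, h=H, w=W):
    """GiB for one batch_size x batch_size block of the pooling covariance."""
    return (batch_size * h * w) ** 2 * BYTES / GIB


# Point 2: does one block fit in 394 GiB if JAX keeps `copies` of it alive?
RAM_GIB = 394
for b in (200, 150, 130):
    fits = {copies: copies * pool_block_gib(b) <= RAM_GIB for copies in (2, 3, 4)}
    print(f"batch {b}: block {pool_block_gib(b):.1f} GiB, fits with 2/3/4 copies: {fits}")
```

Each step from FC to CNN+Flatten to CNN+pooling multiplies the tensor size by 32 * 32 = 1024. With 3-4 live copies, a 200-image block no longer fits in 394 GiB while a 150-image block does; treat these as rough upper bounds, since they ignore everything else resident in memory.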

We'll definitely make a note to prioritize CNN+pooling performance, but FYI I'm not currently aware of simple ways to improve it, so it might take a while.

@HornHehhf
Author

Thanks for your quick reply. I want to double-check that your current suggestions are the following two:

  1. Try a CNN without pooling layers, which will use less memory.
  2. Use a GPU instead of a CPU, with batching.

BTW, another question:
Do you have any idea why the CNTK implementation in Arora et al.'s paper seems so much less costly? Their model is not as general as yours, but maybe their implementation can give you some insight into dealing with the CNN issue.

Thanks again for your quick reply. It is really helpful!

@romanngg
Contributor

  1. Yes, i.e. something like stax.serial(stax.Conv(...), stax.Relu(), stax.Conv(...), ..., stax.Flatten(), stax.Dense(10)). If this is not considerably faster and much less memory-hungry, let me know, since that would mean there's a bug. It should still be slower than the fully-connected kernel, though.

  2. Yes, the batch size will have to be ~O(10) if you use pooling, but it might still be faster overall due to GPU efficiency.

  3. That's definitely on my to-do list! Their implementation does seem really efficient; off the top of my head, I suspect these reasons:

  • Our current implementation computes the CNN kernel propagation as a sequence of convolutions with an identity-matrix filter, i.e. for a CNN layer with a 3x3 kernel we convolve the covariance matrix with [[1, 0, 0], [0, 1, 0], [0, 0, 1]] twice. So the effective operation has 3 summands (the diagonal), but the actual computation sums all 3x3 = 9 entries of the filter.

I imagine (I haven't studied their code yet) that they implemented this efficiently directly in CUDA, while we are using the available JAX/XLA primitives, and among those this seems like the best solution. A while ago we tried some tricks to reduce the operations to 3x1 convolutions, but this only gave us a ~1.3X speedup (rather than the expected 3X), so perhaps there is also some overhead from using JAX/XLA instead of CUDA directly.
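To make the summand counting above concrete, here is a minimal pure-Python sketch for a single spatial dimension, ignoring weight-variance scaling, normalization and padding (VALID only). It is not Neural Tangents code: the hypothetical conv_identity_full slides the whole 3x3 identity filter over an n x n covariance matrix (9 multiply-adds per output, 6 of them by zero), while conv_identity_diagonal sums only the 3 diagonal terms and gives the same result. For 2D images the covariance has one such pair of axes per spatial dimension, which is presumably where the "twice" comes from.

```python
import random


def conv_identity_full(cov, k=3):
    """Slide a k x k identity filter over cov (VALID padding): k*k multiply-adds per output."""
    eye = [[1.0 if a == b else 0.0 for b in range(k)] for a in range(k)]
    m = len(cov) - k + 1
    out = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(m):
            s = 0.0
            for a in range(k):
                for b in range(k):
                    s += eye[a][b] * cov[i + a][j + b]
            out[i][j] = s
    return out


def conv_identity_diagonal(cov, k=3):
    """Same output, summing only the k nonzero (diagonal) terms cov[i+a][j+a]."""
    m = len(cov) - k + 1
    return [[sum(cov[i + a][j + a] for a in range(k)) for j in range(m)]
            for i in range(m)]


rng = random.Random(0)
cov = [[float(rng.randint(-5, 5)) for _ in range(6)] for _ in range(6)]
full = conv_identity_full(cov)
diag = conv_identity_diagonal(cov)
print(full == diag)  # True: the 6 zero-weighted terms per output contribute nothing
```

The diagonal form does a third of the arithmetic, which is the ~3X the 3x1-convolution trick was aiming for.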

This is a bit hand-wavy; I'll definitely need to benchmark/debug in more detail to pinpoint the reason precisely. Hope this helps, lmk if you have any other questions!

@HornHehhf
Author

Yes, the CNN without pooling is much faster, which is acceptable to me. Thanks for your help.

@jaehlee
Collaborator

jaehlee commented Apr 14, 2020

@HornHehhf I just want to add to @romanngg's reply that the CNTK code requires O(1000) GPU hours and processes small batches to deal with the large memory requirements. Their custom CUDA kernel is more efficient than our library's JAX/XLA primitives at the moment; nonetheless, these kernels are inherently compute-intensive.

@HornHehhf
Author

Got it, thanks a lot!

romanngg added the enhancement (New feature or request) label on Apr 15, 2020
@romanngg
Contributor

Btw, I just changed how the CNN kernel is computed without pooling; it should give a ~25% speedup on GPU (no improvement for CNNs with pooling, though).

252ed85

@romanngg
Contributor

Good news: I found a hack that speeds up CNNs with pooling by >4X on NVIDIA GPUs! It's in 100afac and should be available in NT >= 0.2.1.

We should now be comparable in performance to Arora et al., but as @jaehlee remarked, the kernel computation is still inherently costly. This has not been extensively tested yet, though, so please let us know if there are any issues!

I've added some benchmarks to https://github.com/google/neural-tangents#performance; you can use them to estimate how long your task should take.

One sad takeaway from the table, though, is that even a very beefed-up CPU is 40X slower than a single V100 on this task, which makes CPUs especially ill-suited for it. I noticed very low CPU utilization when running single-channel CNNs (which is what we do), and filed a bug with the XLA team; hopefully they can help us with this!

Finally, if you're aiming for top performance on images, you probably do need pooling layers, so there is a tradeoff between speed (Flatten) and task accuracy (GlobalAvgPool). We discussed this phenomenon in https://arxiv.org/pdf/1810.05148.pdf (Figure 1, section 5.1).

@HornHehhf
Author

Great, thanks for your notes. I want to double-check the speed comparison between CPU and GPU. For the CNN with pooling, the CPU is 40X slower than a single V100. What about the CNN without pooling? Does the speed difference still hold there? The time to compute the CNN without pooling on CPU is acceptable to me, but I still want to double-check the GPU vs. CPU comparison in that case.

@romanngg
Copy link
Contributor

Yes, without pooling the ratio also seems to be around 30-40X.

Otherwise, CNN-Flatten seems about 1000X faster than CNN-Pool, which makes sense since the covariance tensor size in that case is 32 * 32 = 1024 times smaller.

@HornHehhf
Author

Great, thanks very much!
