Memory and running time issues for CNN #29
Hey Hangfeng, I'm afraid that computing CNN kernels with pooling layers is in general very costly, so I don't have a good solution for you right now. Some specific thoughts:
We'll definitely make a note to prioritize CNN+pooling performance, but FYI I'm not aware of simple ways of improving it currently, so it might take a while.
Thanks for your quick reply. I want to double-check that your current suggestions are the following two things:
BTW, I also have another question. Thanks again for your quick reply. It is really helpful!
twice. So the effective operation has 3 summands (the diagonal), but the actual computation sums 3x3 = 9 summands, i.e. the whole kernel. I imagine (I haven't studied their code yet) they might have implemented this efficiently directly in CUDA, while we are using the JAX/XLA primitives available, and among those this seems like the best solution. A while ago we did try some tricks to reduce the operations to 3x1 convolutions, but this only yielded a ~1.3X speedup (and not the expected 3X in this case), so perhaps there is also some overhead related to using JAX/XLA rather than CUDA directly.
This is a bit handwavy; I'll definitely need to benchmark/debug in more detail to pinpoint the reason precisely. Hope this helps, lmk if you have any other questions!
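To make the diagonal-vs-full-kernel point above concrete, here is a minimal JAX sketch. It is an illustration only, not the library's actual implementation, and all names and shapes are made up: a size-3 filter whose only non-zero taps lie on the diagonal can be computed with 3 shifted adds, whereas routing it through a dense 3x3 convolution primitive sums all 9 taps, 6 of which are multiplied by zero.

```python
import jax.numpy as jnp
from jax import lax

def diag_conv_via_full_3x3(cov):
  """cov: (batch, h, h', 1) covariance slice -> same shape."""
  # A 3x3 filter that is non-zero only on its diagonal, because
  # E[w_d w_d'] = sigma^2 * delta_{d d'} for i.i.d. weights.
  k = jnp.zeros((3, 3, 1, 1)).at[jnp.arange(3), jnp.arange(3), 0, 0].set(1.0)
  # The dense convolution primitive still sums all 9 taps.
  return lax.conv_general_dilated(
      cov, k, window_strides=(1, 1), padding='SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

def diag_conv_via_shifts(cov):
  """Same result with only 3 summands: a sum of diagonally shifted slices."""
  padded = jnp.pad(cov, ((0, 0), (1, 1), (1, 1), (0, 0)))
  h = cov.shape[1]
  return sum(padded[:, d:d + h, d:d + h, :] for d in range(3))

cov = jnp.arange(64.0).reshape(1, 8, 8, 1)
assert jnp.allclose(diag_conv_via_full_3x3(cov), diag_conv_via_shifts(cov))
```

This only illustrates where the wasted work comes from; it says nothing about which variant XLA actually executes faster in practice.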
Yes, the CNN without pooling is much faster, which is acceptable to me. Thanks for your help.
@HornHehhf I just want to add to @romanngg's reply that the CNTK code requires on the order of O(1000) GPU hours and processes the data in small batches to deal with the large memory footprint. Their custom CUDA kernel is more efficient than our library, which uses JAX/XLA primitives at the moment; nonetheless, these kernels are inherently compute-intensive.
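For reference, a minimal sketch of what computing the kernel in small batches could look like with this library's batching utility (this assumes the `nt.batch` API from the README; the architecture and batch size here are arbitrary placeholders, not the CNTK setup from the paper):

```python
import neural_tangents as nt
from neural_tangents import stax

# Placeholder architecture with a pooling layer.
_, _, kernel_fn = stax.serial(
    stax.Conv(1, (3, 3), padding='SAME'), stax.Relu(),
    stax.AvgPool((2, 2)), stax.Flatten(), stax.Dense(10))

# Compute the kernel block-by-block so that only a small tile of the
# covariance tensor is materialized in device memory at a time.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=20)
# k = batched_kernel_fn(x_train, None, 'nngp')  # (N, N) NNGP kernel
```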
Got it, thanks a lot!
Btw, I just made a change to how the CNN kernel is computed without pooling; it should give about a ~25% speedup on GPU (no improvements to CNN w/ pooling though).
Good news: I found a hack that speeds up CNNs w/ pooling by >4X on NVIDIA GPUs! 100afac, should be there in NT >= 0.2.1. We should now be comparable in performance to Arora et al., but as @jaehlee remarked, the kernel computation is still inherently costly. This is not yet extensively tested though, so please let us know if there are any issues! I've added some benchmarks to https://github.com/google/neural-tangents#performance; you can use them to estimate how long your task should take. One sad takeaway from the table, though, is that even a very beefed-up CPU is 40X slower than a single V100 on this task, which makes the CPU especially ill-suited for it. I noticed that there is very low CPU utilization when doing single-channel CNNs (which we do) and filed a bug with the XLA team; hope they can help us with this! Finally, if you're aiming for top performance on images, you probably do need pooling layers, so there is a tradeoff between speed (flattening) and accuracy (pooling).
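For anyone weighing that tradeoff, here is a hedged sketch of the two readout choices being compared. The widths, depths, and layer choices are arbitrary placeholders (not the benchmark configuration) and assume the standard `stax` layers shown in the README:

```python
from neural_tangents import stax

# Fast, but typically weaker on image tasks: flatten readout, no pooling.
# The kernel only needs the pixel-diagonal of the covariance.
_, _, kernel_fn_flatten = stax.serial(
    stax.Conv(1, (3, 3), padding='SAME'), stax.Relu(),
    stax.Flatten(), stax.Dense(10))

# Much more expensive: a pooling readout requires the full pixel-pair
# covariance, but usually gives better accuracy on images.
_, _, kernel_fn_pool = stax.serial(
    stax.Conv(1, (3, 3), padding='SAME'), stax.Relu(),
    stax.GlobalAvgPool(), stax.Dense(10))

# k = kernel_fn_pool(x1, x2, 'nngp')  # e.g. x1, x2 of shape (N, 32, 32, 3)
```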
Great, thanks for your notes. I want to double-check the speed comparison between CPU and GPU. For the CNN with pooling, the CPU is 40X slower than a single V100. Does a similar gap hold for the CNN without pooling? The time for computing the CNN without pooling on CPU is acceptable to me, but I'd still like to confirm the CPU vs. GPU comparison in that case.
Yes, without pooling the ratio also seems to be in the 30-40X range. Otherwise, CNN-Flatten seems about 1000X faster than CNN-Pool, which makes sense since the covariance tensor in that case is 32 * 32 = 1024 times smaller.
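A quick back-of-the-envelope check of that factor. The shapes below follow the description above (pixel-diagonal covariance for the flatten readout vs. the full pixel-pair covariance for pooling); they are an assumption for illustration, not read out of the library:

```python
# N pairs of 32x32 inputs, float32 (4 bytes per entry).
N, H, W = 100, 32, 32

flatten_entries = N * N * H * W        # covariance over the pixel diagonal only
pool_entries = N * N * H * W * H * W   # full pixel-pair covariance

print(pool_entries // flatten_entries) # 1024 == 32 * 32
print(pool_entries * 4 / 1e9)          # ~41.9 GB for this one tensor alone
```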
Great, thanks very much!