Fix compile times #150

Open
wants to merge 2 commits into
base: master

DhairyaLGandhi
Member

The compile time needed to differentiate DenseNet has regressed significantly over recent releases.

This is often due to very long Chains. This small fix makes things a lot more manageable for the moment.

julia> den = DenseNet();

julia> ip = rand(Float32, 224,224,3,1);

julia> @time gradient((m,x) -> sum(m(x)), den, ip); # before
473.703960 seconds (114.84 M allocations: 7.870 GiB, 0.61% gc time, 99.48% compilation time)

julia> @time gradient((m,x) -> sum(m(x)), den, ip); # after
209.761373 seconds (103.33 M allocations: 7.180 GiB, 1.30% gc time, 98.81% compilation time)

This is a pattern we have across the library, so maybe something to fix elsewhere as well.

Comment on lines +31 to +32
Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
MeanPool((2, 2))]...)
Member

Suggested change
- Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
-        MeanPool((2, 2))]...)
+ Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)]..., MeanPool((2, 2)))

Member

conv_bn already returns a Vector, so shouldn't Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., MeanPool((2, 2))) just work? Also, I know this suggestion has been shot down before because it would cause visual noise, but simply tweaking conv_bn to return a Chain does wonders for the TTFG (time to first gradient):

master:

julia> using Metalhead

julia> using Flux: Zygote

julia> den = DenseNet();

julia> ip = rand(Float32, 224, 224, 3, 1);

julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
 77.621622 seconds (124.76 M allocations: 11.324 GiB, 1.67% gc time, 97.00% compilation time)

with conv_bn returning a Chain:

julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
 28.244888 seconds (89.40 M allocations: 9.049 GiB, 3.60% gc time, 90.78% compilation time)
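
For readers following along, here is a minimal sketch of the two shapes being compared. It does not use Metalhead's actual conv_bn: the helpers conv_bn_vec and conv_bn_chain and both transition_* functions are hypothetical stand-ins, written under the assumption that rev = true means normalisation first and convolution second. It only shows how the layer structure differs and makes no claim about compile times.

```julia
using Flux

# Hypothetical stand-in for a conv_bn-style helper (NOT Metalhead's real one):
# returns the layers as a Vector, BatchNorm first, then the convolution.
conv_bn_vec(k, inplanes, outplanes) =
    [BatchNorm(inplanes, relu), Conv(k, inplanes => outplanes; bias = false)]

# Hypothetical variant that returns a Chain instead of a Vector.
conv_bn_chain(k, inplanes, outplanes) =
    Chain(conv_bn_vec(k, inplanes, outplanes)...)

# Flat version: the Vector is splatted straight into the outer Chain.
transition_flat(inplanes, outplanes) =
    Chain(conv_bn_vec((1, 1), inplanes, outplanes)..., MeanPool((2, 2)))

# Nested version: the conv/bn pair stays grouped as its own inner Chain.
transition_nested(inplanes, outplanes) =
    Chain(conv_bn_chain((1, 1), inplanes, outplanes), MeanPool((2, 2)))

x = rand(Float32, 8, 8, 4, 1)
flat = transition_flat(4, 2)
nested = transition_nested(4, 2)

# Both compute the same kind of mapping; only the nesting differs.
@show size(flat(x)) size(nested(x))
@show length(flat.layers) length(nested.layers)
```

The flat version holds three layers at the top level, while the nested one holds two (an inner Chain plus the pooling layer). Whether that grouping helps TTFG in a given model is something to measure, as in the timings above.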

Member

@theabhirath, Apr 27, 2022

^ This needs some tricks to get it this fast, though. One major trick is that large Vectors should not be splatted to build Chains (Flux 0.13 deals with this, so this works). Removing a single splat to a large vector of layers (the "body" of the DenseNet) makes it shoot back up:

julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
 46.788491 seconds (117.59 M allocations: 10.873 GiB, 2.65% gc time, 94.90% compilation time)
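
As a rough illustration of the splat-versus-Vector point, the sketch below builds the same stack of layers both ways. It assumes Flux 0.13 or later, where Chain accepts a Vector of layers directly; the toy Dense layers and the depth of 50 are arbitrary choices for the example, not taken from DenseNet.

```julia
using Flux

# A long, homogeneous stack of layers (toy sizes, chosen only for illustration).
layers = [Dense(4 => 4, relu) for _ in 1:50]

# Splatting stores the layers as a Tuple, so the Chain's type grows with depth.
splatted = Chain(layers...)

# Passing the Vector itself (assumes Flux >= 0.13) keeps the layers in a Vector.
as_vector = Chain(layers)

x = rand(Float32, 4, 1)

# Both chains share the same layer objects, so their outputs agree.
@show splatted(x) ≈ as_vector(x)
@show splatted.layers isa Tuple
@show as_vector.layers isa AbstractVector
```

The Vector-backed Chain trades some type information for a much smaller type, which is the usual motivation for using it on very deep models.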

Member

Whoops, you are indeed right, and the suggestion looks good.

Member Author

One thing I am curious about is the large discrepancy between first compiles on master. I regularly get ~500s TTFG with DenseNet, but your times are nowhere near as bad. Mine is with GPUs turned off. Does that account for some of the difference?

Member

I am testing on an M1 Mac CPU, with 4 threads and Julia master. Maybe some of the discrepancy comes from there? Last I checked, Julia 1.8+ seemed to be an order of magnitude faster than Julia 1.7 at compiling some things.
