
Improved time to first gradient #151

Merged · 3 commits · May 1, 2022
Conversation

theabhirath
Member

@theabhirath theabhirath commented Apr 27, 2022

Edit: this initially had some benchmarks that weren't completely accurate, because I'd left the REPL running and so they weren't measuring the first Zygote.gradient call. The DenseNet benchmark is accurate in this regard.

This PR (building on the work done by @DhairyaLGandhi in #150) uses a Flux v0.13 feature (namely, the fact that Chain(::Vector) is valid syntax), along with returning a Chain as the output of conv_bn, to halve compilation time for most models (and better still for some). From a cold start (first Zygote.gradient call):

julia> model = DenseNet();

julia> ip = rand(Float32, 224, 224, 3, 1);

julia> @time Zygote.gradient((m,x) -> sum(m(x)), model, ip);
 78.400696 seconds (124.71 M allocations: 11.321 GiB, 1.71% gc time, 96.65% compilation time)
julia> @time Zygote.gradient((m,x) -> sum(m(x)), model, ip);
 28.161918 seconds (88.19 M allocations: 8.970 GiB, 3.66% gc time, 89.48% compilation time)
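The two changes can be sketched roughly like this (a simplified illustration, not the exact Metalhead code; the `conv_bn` signature here is abbreviated):

```julia
using Flux

# Flux v0.13 allows building a Chain from a Vector. The Vector erases the
# per-layer types from the Chain's type parameters, so Zygote has far less
# code to specialize on, at a small runtime cost versus a Tuple-backed Chain.
layers = []
push!(layers, Conv((3, 3), 3 => 16; pad = 1))
push!(layers, BatchNorm(16, relu))
push!(layers, MaxPool((2, 2)))
model = Chain(layers)   # Vector-backed Chain

# Sketch of conv_bn returning a Chain, instead of an array of layers to be
# splatted by the caller (argument names are illustrative):
function conv_bn(kernelsize, inplanes, outplanes; stride = 1, pad = 0)
    return Chain(Conv(kernelsize, inplanes => outplanes; stride, pad, bias = false),
                 BatchNorm(outplanes, relu))
end
```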

@theabhirath
Member Author

Seems to have eased up some memory pressure as well - Ubuntu tests on nightly are now passing 🎉

@ToucheSir
Member

This is the first time that compilation time for a first gradient has gone under 90%. I can't believe my eyes. Is it safe to say that DenseNet TTFG is no longer a concern either?

@theabhirath
Member Author

theabhirath commented Apr 27, 2022

Well, there's probably a way to get it down even further but for now, this improvement looks pretty surreal (Chain(::Vector) is an absolute beast 😳).

Before:

julia> model = DenseNet();

julia> ip = rand(Float32, 224, 224, 3, 1);

julia> @time Zygote.gradient((m,x) -> sum(m(x)), model, ip);
 78.400696 seconds (124.71 M allocations: 11.321 GiB, 1.71% gc time, 96.65% compilation time)

This PR:

julia> @time Zygote.gradient((m,x) -> sum(m(x)), model, ip);
 28.161918 seconds (88.19 M allocations: 8.970 GiB, 3.66% gc time, 89.48% compilation time)

@theabhirath
Member Author

theabhirath commented Apr 27, 2022

This is the first time that compilation time for a first gradient has gone under 90%

This might be slightly misleading; I think I left the REPL running 😅 The exact numbers vary between runs, but one thing is clear: in every case (including a completely fresh REPL), there's at least a 2x improvement. There is also a consistent 17-18 seconds that Zygote itself takes to compile the first gradient call. I'm not sure whether that can go down, but it would help, because I think the models themselves are currently doing all they can.

@theabhirath
Member Author

More benchmarks. This is how long it took to run the tests in February:

[screenshot: test run times, February]

And this is today, this PR:

[screenshot: test run times with this PR, April 27, 2022]

Definitely a step in the right direction 😁

@DhairyaLGandhi
Member

DhairyaLGandhi commented Apr 27, 2022

there's at least 2x improvement

That's right, and something I showed in #150 as well. The difference is that I see an order of magnitude more on first compile, which is very curious.

17-18 seconds that Zygote itself takes to compile the first gradient call - not sure if there's some way that can go down

Yes, there is. I have experimented with precompile statements; in fact, we used to call gradient in Zygote during precompilation for exactly this reason. We were able to shave the "compile Zygote" time by almost an order of magnitude, IIRC.
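A minimal sketch of that idea (hypothetical; the exact statements Zygote used during precompilation may have differed):

```julia
module MyPackage

using Flux, Zygote

# Running a small gradient at precompile time forces much of the generic AD
# machinery to compile ahead of time, so the first user-facing
# Zygote.gradient call is cheaper. Sizes and layers here are illustrative.
let m = Chain(Dense(2 => 2, relu), Dense(2 => 1)), x = rand(Float32, 2, 1)
    Zygote.gradient((m, x) -> sum(m(x)), m, x)
end

end # module
```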

I think the models are currently doing all they can

That's not strictly the case 😅 If you were to try an older Metalhead + Flux + Zygote, you would see ~4x faster TTFGs in some cases. There are still some tricks we can apply to ease the compilation pressure, mostly to do with caching and stability.

@ToucheSir
Member

There are still some tricks we can apply to get compilation pressure eased off, mostly to do with caching and stability.

I'm only aware of the switch from custom layer types to Chain, do you have any pointers to example code from back in the day that shows this?

@darsnack
Member

Can you test this without returning a Chain for conv_bn but still using a Chain(::Vector) everywhere? Just to separate how much benefit is coming from each technique. DNNs are...deep, so having a "fix" for the long chains issue will be important outside of Metalhead.

This is really a sad state of affairs for Zygote that Julia 1.5 -> 1.6 caused such major performance regressions for the most basic operation in ML.

@darsnack
Member

Even the "old" Metalhead used a flat chain for VGG, and DenseNet used a flat chain, replacing only SkipConnection with a custom struct. Focusing on Metalhead as the source of regressions is a red herring, IMO. There are issues on Flux/Zygote pointing out TTFG regressions with Zygote v0.6 for models not from Metalhead.

@darsnack
Member

@ToucheSir I tried testing with FluxML/Zygote.jl#1195 and an afoldl implementation of Chain to see if it helps. Sadly, it doesn't appear to make a difference though I'm not sure that I set everything up correctly.

Another curiosity is that the CI test times have not gone down for this PR. @theabhirath are both the screenshots of the tests above with the same Julia version? I know you like to run nightly so I'd be curious if Julia versions are making a big difference here.

@theabhirath
Member Author

theabhirath commented Apr 27, 2022

Can you test this without returning a Chain for conv_bn but still using a Chain(::Vector) everywhere

There doesn't seem to be much difference in gradient times, but it shaves some time off the forward pass (returning a Chain from conv_bn, that is).

@theabhirath
Member Author

theabhirath commented Apr 27, 2022

are both the screenshots of the tests above with the same Julia version?

Well, they're both nightly 😅 But there have been no major PRs to master that I think could have changed things this drastically, and nothing else has changed between the runs.

Another curiosity is that the CI test times have not gone down for this PR

That may be limited by memory? I'm on a 16 GB machine, while IIRC the runners have less to work with (7 GB, I think?). I'm not sure whether the difference should still show up in some fashion, though.

@ToucheSir
Member

@darsnack that PR won't help TTFG much since Zygote + IRTools still has to churn through all of the control flow in afoldl. IIRC it should be strictly worse than using Chain{Vector}. The main benefits come at runtime.

What might help is optimizing the AD compilation pass itself. I have a local IRTools branch that shaves ~10s off TTFG for ViT through a combination of precompilation and reducing memory allocations in one particularly time-consuming function. However, it's unclear how much mileage is left for this approach, as profiling suggests a lot of time is spent in inference or LLVM. Perhaps 1.8/9 will help with those?

@darsnack
Member

There doesn't seem to be much difference in gradient times but it shaves some time off the forward pass (returning Chain for conv_bn that is)

Just for clarity: you're saying that Chain(::Vector) contributes most of the TTFG improvement or nested Chains (as a result of returning Chain from conv_bn)?

@theabhirath
Member Author

theabhirath commented Apr 27, 2022

Just for clarity: you're saying that Chain(::Vector) contributes most of the TTFG improvement or nested Chains (as a result of returning Chain from conv_bn)?

Chain(::Vector) primarily contributes to reducing TTFG. The nested Chains are helping reduce inference time a bit: about 20-100 ms knocked off the forward pass, depending on the model you check.
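One rough way to see the Chain(::Vector) effect in isolation (a sketch only; each @time should really be run in a fresh Julia session to measure a true first gradient):

```julia
using Flux, Zygote

x = rand(Float32, 8, 1)

# Tuple-backed Chain: every layer's type appears in the Chain's type,
# so the whole pullback is specialized per layer position.
tuple_chain = Chain((Dense(8 => 8, relu) for _ in 1:50)...)

# Vector-backed Chain: one erased element type, much less to specialize.
vector_chain = Chain([Dense(8 => 8, relu) for _ in 1:50])

@time Zygote.gradient(m -> sum(m(x)), tuple_chain);
@time Zygote.gradient(m -> sum(m(x)), vector_chain);
```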

@DhairyaLGandhi
Member

Great, thanks @theabhirath ! I think this is good to go since there is plenty of improvement in there already and we can move ahead with the compilation tirade.

Returning nested Chains and returning Chain out of conv_bn actually contribute a lot.

One final thing would be to re-enable testing gradients of the models; those are currently skipped.

@ToucheSir ToucheSir merged commit 792076f into FluxML:master May 1, 2022
@theabhirath
Member Author

One final thing would be to reenable testing gradients out of the models. Those are skipped currently.

The memory issues on GitHub Actions prevent this; testing locally does take a lot of memory (I've had to intervene to make sure it doesn't write too much into my swap).

@theabhirath theabhirath deleted the conv_bn branch May 1, 2022 02:41
@darsnack
Member

darsnack commented May 1, 2022

This is my fault for not commenting earlier, but I would actually prefer a follow-up PR to remove the nesting. Not just because arbitrary nesting makes iteration and indexing inconvenient, but more practically because nesting is a breaking change. And it doesn't seem necessary when using Chain(::Vector), which makes sense, since the long-chains issue should only affect inference for tuples.

@theabhirath
Member Author

theabhirath commented May 1, 2022

I have a ready followup PR, but here's something I noted while testing with the gradtests on.

With conv_bn returning a Chain, time taken to run the tests:
[screenshot: test run times with conv_bn returning a Chain]

Without (i.e. manually splatting conv_bn everywhere):
[screenshot: test run times with conv_bn splatted manually]

Now, I tested TTFG for some models and it stayed the same; but clearly, this shows that having conv_bn return a Chain helps subsequent gradients. I'm not sure what the way forward is. Like I said, I have a follow-up PR ready in case we choose to revert, but this approach seems to be yielding better results (the breaking-change part is definitely annoying, though; I'm not sure how to get around that).

@darsnack
Member

darsnack commented May 1, 2022

No, a breaking change is okay if it is actually making a difference.

@darsnack
Member

darsnack commented May 1, 2022

Why is the test time so different for AlexNet? It contains no conv_bns, and the merged code doesn't attempt to use Chain(::Vector) for it either; i.e. shouldn't it be exactly the same in those two screenshots?

I feel like we need a more rigorous benchmarking environment beyond ] test to make these decisions.

@theabhirath
Member Author

True, this is completely local and it's not really much of a benchmark because I've just run the tests twice 😅
