[model_utils] very slow model instantiation #9205

Closed
stas00 opened this issue Dec 19, 2020 · 21 comments · Fixed by #11471
Assignees: stas00
Labels: Performance, WIP

Comments

@stas00
Contributor

stas00 commented Dec 19, 2020

For some reason I'm noticing a very slow model instantiation time.

For example, loading sshleifer/distill-mbart-en-ro-12-4 takes:

  • 21 secs to instantiate the model
  • 0.5sec to torch.load its weights.

If I'm not changing how the model is created and just want to fast-forward to the part I'm debugging, how could these slow steps be cached rather than rebuilt from scratch on every run?

But it also looks like we are doing a completely wasteful init_weights pass whose results are immediately overwritten by the pretrained weights (#9205 (comment)) (for the use case of loading a pre-trained model).

(I initially made a mistake and thought that it was torch.load that had an issue, but it's actually cls(config, *model_args, **model_kwargs) - thank you, @sgugger - so this post has been edited to reflect reality. If you're joining later you can skip the comments up to #9205 (comment) and continue from there.)

@patrickvonplaten, @sgugger, @LysandreJik

@sgugger
Collaborator

sgugger commented Dec 20, 2020

Doesn't that script also load and preprocess the data? From what you're reporting, I don't interpret this as "transformers takes a long time to load the model" (since the line that does that takes the same time as a torch load) but as "stuff that happens in that script before the model loading takes a lot of time" (which is probably data preprocessing + the 3s it takes to import transformers if TF is in your env). Or am I missing something?

@stas00
Contributor Author

stas00 commented Dec 20, 2020

Perhaps my first post is confusing; what I did was bracket the torch.load call in modeling_utils.py:

        import time

        start_time = time.time()
        state_dict = torch.load(resolved_archive_file, map_location="cpu")
        end_time = time.time() - start_time
        print(f"torch.load took {round(end_time, 4)} secs")

So all the other stuff isn't being measured, just the torch.load call.

@sgugger
Collaborator

sgugger commented Dec 20, 2020

Ah, I understand better. I don't think your comparison is fair: AutoModel.from_pretrained does two things: creating the model and filling it with the weights. From a small timing experiment on my side, I believe all the time is spent in the model creation. So you should compare the timing of creating the model plus loading the weights into it to have something that's apples to apples.

@stas00
Contributor Author

stas00 commented Dec 21, 2020

I removed the 2nd part, which was showing the same issue from a different angle, as it only added confusion and wasn't contributing to understanding the issue at hand.

There is just state_dict = torch.load(resolved_archive_file, map_location="cpu") call - and nothing else. On its own:
python -c "import torch; torch.load('/hf/transformers-master/data/distill-mbart-en-ro-12-4/pytorch_model.bin')"
it takes ~1s; the exact same call inside modeling_utils takes 22+ secs.

@stas00
Contributor Author

stas00 commented Dec 21, 2020

OK, somehow I made a mistake and was taking the snapshot of the start time before model = cls(config, *model_args, **model_kwargs) rather than before torch.load() - my apologies :( and thank you for double-checking my invalid report.

        import time
        t0 = time.time()
        model = cls(config, *model_args, **model_kwargs)
        t1 = time.time()
        state_dict = torch.load(resolved_archive_file, map_location="cpu")
        t2 = time.time()
        print(f"cls init { round(t1-t0, 4)}")
        print(f"load     { round(t2-t1, 4)}")
        import sys
        sys.exit(0)

cls init 21.2055
load     0.5074

So it's setting up the model that takes so long, just as you said.

Can this somehow be sped up? I was integrating deepspeed and re-running the same command repeatedly, and 23 extra secs of waiting just to discover that something is off was very painful for debugging. All the failures happened at much later stages. I worked around it by switching to a tiny model, but even that takes some secs.

Can we think of a way to make an image of the model and load it, rather than rebuilding the model from scratch? So we torch.load the weights, but also cache the instantiated model itself and load that too, rather than create it anew. It seems so wasteful and slow when I'm not debugging the model creation but, say, tuning something in the trainer, and I want the other parts to load blazingly fast and get me to the point of interest quickly. What would be the best way to approach such a need?
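
For illustration, a crude local workaround along these lines could be to pickle the fully constructed model object once and reload it on subsequent debug runs. This is only a sketch, not a transformers feature: the cache path below is made up, and the pickle only stays valid while the model code and PyTorch version are unchanged.

# illustrative local caching workaround (not a transformers feature)
import os
import torch
from transformers import AutoModelForSeq2SeqLM

CACHE = "/tmp/distill-mbart-en-ro-12-4.cached.pt"  # made-up cache location

if os.path.exists(CACHE):
    # unpickles the whole nn.Module, weights included, skipping __init__/init_weights
    model = torch.load(CACHE, map_location="cpu")
else:
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-mbart-en-ro-12-4")
    torch.save(model, CACHE)  # cache the instantiated model object itself, not just its state_dict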

@stas00 stas00 changed the title [model_utils] very slow model loading [model_utils] very slow model instantiation Dec 21, 2020
@stas00
Contributor Author

stas00 commented Dec 21, 2020

Profiling the model instantiation code shows that _init_weights is where some 75% of that slowdown happens:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      354   18.942    0.054   18.942    0.054 {method 'normal_' of 'torch._C._TensorBase' objects}
      225    2.286    0.010    2.286    0.010 {method 'uniform_' of 'torch._C._TensorBase' objects}

So we are completely wasting time initializing weights that we immediately replace (with the exception of SinusoidalPositionalEmbedding, which does not get loaded from the pretrained model).

If you prefer the visual version:

(call-graph screenshot: snapshot_2)

Chances are that model init needs to be made context-aware so it doesn't initialize weights that will be immediately replaced. Thoughts?

That would make transformers so much faster to start! (e.g. think of the model pages on the website, which take forever to load a model).

The profiling was done with:

# prep
pip install graphviz gprof2dot
cat <<EOT > prog
from transformers import AutoModelForSeq2SeqLM
AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-mbart-en-ro-12-4")
EOT

# text profile
USE_TF=0 PYTHONPATH=src python -m cProfile -s tottime prog > profile.txt
head -10 profile.txt

# visual profile
USE_TF=0 PYTHONPATH=src python -m cProfile -o profile.pstats prog
gprof2dot -f pstats profile.pstats |  dot -Tsvg -o callgraph.svg
display callgraph.svg

@patrickvonplaten
Contributor

patrickvonplaten commented Dec 21, 2020

If we see a significant gain in loading time, maybe it's worth exploring a way to only apply init_weights to the missing layers. Not sure how easy it would be to implement, though...

Maybe an init_weights argument in __init__ might make sense:

model = cls(config, init_weights=False, *model_args, **model_kwargs)  # don't call init_weights, but initialize all weights to zero because it's much faster
# load weights into model and get missing layers
# init missing layers
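
For illustration, one hypothetical shape of this suggestion (the init_weights argument and the class below are made up, not the actual transformers API) is a constructor flag that lets from_pretrained skip the expensive random init and leave the missing layers to the loading code:

import torch.nn as nn

class SketchModel(nn.Module):
    # sketch of a PreTrainedModel-like class with an opt-out for weight init
    def __init__(self, config, init_weights: bool = True):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(config.d_model, config.d_model)  # placeholder submodule
        if init_weights:
            self.apply(self._init_weights)  # today's behaviour: randomly init everything
        # else: from_pretrained loads the checkpoint first and then initializes
        # only the layers reported as missing (see the sketch further below)

    def _init_weights(self, module):
        # model-specific init, as in transformers
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()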

@sgugger
Collaborator

sgugger commented Dec 21, 2020

Yeah Patrick's suggestion is probably the best, though I'm not sure it can easily be achieved in the current API. Note that this is only one slowdown at the beginning of training, so I don't think this should be high priority.

@stas00
Contributor Author

stas00 commented Dec 21, 2020

I totally get that it's not high priority, since most people don't mind a slow start when they then run non-stop for hours - it only affects people who need a quick start, which is the case when debugging something or, as I suggested, the demo function on the model pages, which takes a really long time to load.

In the case of BART, its deterministic segments do their init internally, so it's enough to monkeypatch as a proof of concept:

        # modeling_utils.py::from_pretrained
        init_weights_orig = PreTrainedModel.init_weights
        def init_weights_pretrained(self):
            # self.apply(self._init_weights)
            if self.config.pruned_heads: self.prune_heads(self.config.pruned_heads)
            self.tie_weights()
            
        PreTrainedModel.init_weights = init_weights_pretrained
        model = cls(config, *model_args, **model_kwargs)
        PreTrainedModel.init_weights = init_weights_orig

and this command:

PYTHONPATH=../../src USE_TF=0 time python -c 'from transformers import AutoModelForSeq2SeqLM; AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-mbart-en-ro-12-4")'

goes from 25 secs to 8 secs. The instantiation alone goes from 22 secs to 5 secs.

There are a few uniform_ calls left, which account for 2.3 extra secs; if those are shaved off as well, we should be down to 2-3 secs (from 22!).

I quickly checked that the model still functions normally - same scores - though I only did one finetune_trainer run.

One way is to solve this as @patrickvonplaten suggested. I'm also thinking of changing the design a bit, so that each model has a normal init_weights and an init_weights_pretrained - then it's very clear to the developer what goes where, and we simply invoke one or the other depending on the context. Then it's just a matter of choosing how to signal the context.

@sgugger
Collaborator

sgugger commented Dec 22, 2020

I don't see how you could have an init_weights_pretrained: it depends on the checkpoint you pass. If you pass the checkpoint of a BertModel to BertForMaskedLM, you just have one bias to initialize (if weights are tied), but if you pass a BertForMaskedLM checkpoint then you have nothing to initialize. And the same holds for every variant (each would have different specific weights to initialize in the case of a pretrained model), so I don't really see how you can do this API-wise.

The only way I see through it is to allow init_weights to take the list of model parameters to randomly initialize, but since we use the apply method afterward (and rely on it to get the modules inside each model-specific _init_weights method) I don't see how to use it properly. It would probably require some clever recursive method.

Again, lots of headaches and possibilities for errors for an end result that doesn't strike me as high priority.

it only affects people who need a quick start, which is the case when debugging something or, as I suggested, the demo function on the model pages, which takes a really long time to load.

It doesn't take 25 seconds on a tiny model, only a big one. So I'd suggest debugging on a tiny model :-)
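
For illustration, the kind of "clever recursive method" discussed above could look roughly like the following: after a non-strict load, walk the named modules and re-run the model-specific _init_weights only on modules that own parameters the checkpoint did not supply. This is a sketch, not the library's actual loading code; only model._init_weights is assumed to exist as it does in transformers.

def init_missing_weights_only(model, state_dict):
    # load what the checkpoint provides and record what it does not
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = set(load_result.missing_keys)
    for name, module in model.named_modules():
        # names of parameters/buffers registered directly on this module
        local = [n for n, _ in module.named_parameters(recurse=False)]
        local += [n for n, _ in module.named_buffers(recurse=False)]
        prefix = f"{name}." if name else ""
        if any(prefix + n in missing for n in local):
            model._init_weights(module)  # selective, model-specific init
    return load_result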

@stas00
Contributor Author

stas00 commented Dec 22, 2020

Thank you both for entertaining possible approaches and for saying that you don't quite see a smooth solution. I just don't know enough about all of it, so I'm surely missing cases I haven't thought of, but somehow in my mind it looks simple. The devil is in the details.

It doesn't take 25 seconds on a tiny model, only a big one. So I'd suggest debugging on a tiny model :-)

Unfortunately the tiny model approach doesn't work for debugging OOM in deepspeed, as its configuration correlates with the model size. I guess that's not specific to deepspeed at all. So the tiny model trick works for checking mechanics (i.e. that the code runs), but isn't helpful for OOM debugging.

@stas00
Contributor Author

stas00 commented Feb 4, 2021

@patrickvonplaten, @sgugger, @LysandreJik - could we please revisit this? Working on getting t5-11b to train was painful - it was taking a really, really long time to init the model, just to drop those weights and replace them with the pre-trained ones. Transformers is mainly about pre-trained models, so perhaps this can be made configurable somehow?

We know when a pretrained model is loaded, so why not propagate that information and let the model know it's being loaded in pre-trained mode, so that it could skip any weight inits that are going to be replaced anyway?

And while we are at it, I don't suppose there is a way to involve more than one CPU core in loading the model? I guess that would be a question for pytorch.

Thank you!

@patrickvonplaten
Contributor

I'm happy to add such a feature. It should be feasible to only initialize those layers that are not in the saved .pt file.

@LysandreJik
Member

Indeed, this would be a welcome feature, big models aren't going away.

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been automatically marked as stale and closed because it has not had recent activity. Thank you for your contributions.

If you think this still needs to be addressed please comment on this thread.

@stas00
Contributor Author

stas00 commented Mar 6, 2021

@patrickvonplaten, I should probably work on it - since it doesn't seem like you will have time any time soon.

@stas00 stas00 reopened this Mar 6, 2021
@stas00 stas00 self-assigned this Mar 6, 2021
@stas00 stas00 added the Performance and WIP labels and removed the wontfix label Mar 6, 2021
@patrickvonplaten
Contributor

patrickvonplaten commented Mar 8, 2021

It's on my to-do list, but I still don't think I'll be able to take a look within the next 2-3 weeks - sorry :-/ If you find some time for this, it would be great.

@stas00 stas00 mentioned this issue Mar 31, 2021
@AyeshaSarwar

I have fine-tuned a Longformer encoder-decoder model and am trying to convert it into an API, but the model takes so long to load that the API throws a not-responding error (code attached).
Could anyone kindly guide me on how I can reduce the model's loading time?
Thank you in advance.

@patrickvonplaten
Contributor

Hello @AyeshaSarwar,

could you please use the forum https://discuss.huggingface.co/ for such questions instead? We don't support Flask compatibility in transformers. Please keep in mind that GitHub issues are mainly used for problems related to transformers itself.

Thanks

@DeXtmL

DeXtmL commented Sep 29, 2022

I'm in the same boat as @stas00. I understand that the code needs to maintain wide compatibility across an ocean of models, but people need a working workaround before an elegant solution comes into being. I believe that as Hugging Face slowly graduates from the pure research field, more and more people are being hurt by the tremendous model initialization time.
Hoping for a change

@stas00
Contributor Author

stas00 commented Sep 29, 2022

@DeXtmL, this thread is 2 years old - the particular problem I raised in this Issue has been solved a long time ago. The model is no longer being init'ed twice.

If you feel something is still slow please start a new Issue.

thank you.
