
[model parallelism] Bart goes parallel #9384

Closed
wants to merge 9 commits

Conversation

stas00
Contributor

@stas00 stas00 commented Jan 2, 2021

This PR implements model parallelism (MP) in Bart.

This is the latest incarnation of the generalization of MP in transformers, based on @alexorona's original work. I have already done some of it in #9323, and this PR builds upon that one. The question of what to merge when is a bit complicated, but this PR is independent and can be merged on its own.

For reviewers I propose to read things in this order:

  1. [t5 model parallel] misc fixes #9316
  2. [T5 model parallel] implement input auto-relocation + lots of refactoring/cleanup #9323
  3. this PR
  4. Additional important design discussions: Model Parallelism and Big Models #8771

If all is in agreement, I propose:

  1. ☐ merging this PR first,
  2. ☐ then I'll backport the new code from this PR to [T5 model parallel] implement input auto-relocation + lots of refactoring/cleanup #9323 and we merge that.
  3. ☐ then we handle gpt2, which I haven't touched yet. Perhaps @alexorona could help there if his time permits, or one of us will.
  4. ☐ complete Bart's other heads (can be item 3) and deparallelize - the latter is not really needed in practice, so I will handle those once the dust around the design settles.
  5. ☐ add Bart to the models supported by the trainer's --model_parallel flag
  6. ☐ write tests for model_parallel_utils.py
  7. ☐ meanwhile we can polish the concept of device maps, which will require a review of all the architectures transformers has implemented.

Actually first we need to merge smaller bits:

  1. [trainer] --model_parallel hasn't been implemented for most models #9347
  2. replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm #9386

So this PR:

  • Implements MP in Bart based on discussions in all of the threads/PRs listed above. Only BartForConditionalGeneration at the moment, while we are sorting out the API. But the bulk of the work is done, since BartModel has it all in place.

  • switches to the concept of a main_device rather than (first|last)_device: the first device of the encoder becomes the main_device and almost everything happens there (embeddings, lm_head, etc.), while the other devices are used exclusively for encoder and decoder purposes.

  • switches to a more explicit device_map that can support non-symmetrical models (a different number of layers in the encoder and decoder). It can also handle different types of maps. See the demo at the end of this post for details.

  • further improves the magical to() functions, which can operate on any type of variable except opaque objects. They can be used to put the inputs on the correct devices either automatically via a forward decorator or explicitly inside forward; we could use either or both (see the sketch after this list).

  • adds a bunch of debug functions that make it easy to trace device IDs of variables, params and whole layers.

  • further improves the device map validation function

  • improves tests

  • needs apex.normalization.FusedLayerNorm removed, as it's buggy under MP (corrupts data) - see #9377; the dedicated removal PR ("replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm") is #9386.
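To make the input auto-relocation idea concrete, here is a minimal sketch of what such a recursive to() helper and forward decorator could look like (assumed names, not necessarily the PR's actual model_parallel_utils.py code):

import functools
import torch

def recursive_to(device, obj):
    # Move any tensor found in (possibly nested) lists, tuples and dicts to `device`;
    # opaque objects are returned untouched.
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, list):
        return [recursive_to(device, o) for o in obj]
    if isinstance(obj, tuple):
        return tuple(recursive_to(device, o) for o in obj)
    if isinstance(obj, dict):
        return {k: recursive_to(device, v) for k, v in obj.items()}
    return obj

def auto_relocate_inputs(forward):
    # Decorator for a module's forward: relocate all inputs to the device of the
    # module's own parameters before the real forward runs.
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        device = next(self.parameters()).device
        return forward(self, *recursive_to(device, list(args)), **recursive_to(device, kwargs))
    return wrapper

A layer's forward could then either be wrapped with @auto_relocate_inputs or call recursive_to explicitly on its inputs.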

Here is a quick demo (you will need 2 gpus to run it):

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

#mname = "sshleifer/tinier_bart"
mname = "sshleifer/distilbart-xsum-6-6"

model = BartForConditionalGeneration.from_pretrained(mname)
tokenizer = BartTokenizer.from_pretrained(mname)

sentences = ["I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder."]
inputs = tokenizer(sentences, max_length=1024, return_tensors='pt', truncation="longest_first")
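
# Device-map format used below (still unstable): for "encoder" and "decoder" separately,
# a dict of {gpu_id: [indices of the layers placed on that gpu]}.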

device_maps_flat = {
    "sshleifer/tinier_bart": {
        "encoder": {0: [0, 1] },
        "decoder": {1: [0] },
    },
    "sshleifer/distilbart-xsum-6-6": {
        "encoder": {0: [0, 1, 2, 3, 4, 5] },
        "decoder": {1: [0, 1, 2, 3, 4, 5] },
    },
}

device_maps_split = {
    "sshleifer/tinier_bart": {
        "encoder": {0: [0],
                    1: [1],
                    },
        "decoder": {1: [0] },
    },
    "sshleifer/distilbart-xsum-6-6": {
        "encoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
        "decoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
    },
}

# 3 different ways (2 different device maps and 1 autogenerated device map)
model.parallelize() # autogenerated
#model.parallelize(device_maps_flat[mname])
#model.parallelize(device_maps_split[mname])

inputs = inputs.to("cuda:0")
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=25, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
# prints: [" I'm sitting in a room where I'm waiting for something to happen."]

You can see from the demo that when calling model.parallelize() you can skip the device_map arg altogether and the model will generate the right one (a rough sketch of how such a default map could be built follows after the list below). Or you can provide one that:

  1. gives some gpus exclusively to the encoder and others to the decoder
  2. splits the model horizontally, so that the encoder uses all gpus and so does the decoder

In either case the model transparently handles all the remappings.
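
For illustration, a default device map could be autogenerated roughly like this (hypothetical helper, not necessarily the PR's exact logic): spread the encoder and decoder layers evenly over the visible GPUs.

import torch

def make_default_device_map(n_encoder_layers, n_decoder_layers, n_gpus=None):
    n_gpus = n_gpus or torch.cuda.device_count()
    def split(n_layers):
        per_gpu = -(-n_layers // n_gpus)  # ceiling division
        return {gpu: list(range(gpu * per_gpu, min((gpu + 1) * per_gpu, n_layers)))
                for gpu in range(n_gpus) if gpu * per_gpu < n_layers}
    return {"encoder": split(n_encoder_layers), "decoder": split(n_decoder_layers)}

# For distilbart-xsum-6-6 on 2 gpus this produces the same layout as device_maps_split above:
# {"encoder": {0: [0, 1, 2], 1: [3, 4, 5]}, "decoder": {0: [0, 1, 2], 1: [3, 4, 5]}}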

Note that the user still needs to put the data on the main_device, so perhaps that could eventually stop being hardcoded via:

# inputs = inputs.to("cuda:0")
inputs = inputs.to(model.main_device)

As we have been discussing elsewhere, the device map format is not yet stable. So I propose we document it as unstable for now; users can instead rely on the autogenerated device map, which requires no input from them (i.e. calling `model.parallelize()` with no arguments) - if the format changes, it'll happen transparently for the user.

Also note that with Trainer-based scripts, like finetune_trainer.py, the user currently has no way to supply such a device map, so in effect the model generates the map on the fly, as described in the previous paragraph.

Fixes: #8344

@LysandreJik, @patrickvonplaten, @sgugger, @alexorona

@patrickvonplaten
Contributor

That looks great! Model parallelism would be very nice for Bart. We should coordinate here a bit with all the open PRs. I'm also more or less done with the big "split-bart-into-separate-models" PR: #9343.
I think the merge conflicts can become a bit painful here :D.

I'd propose the following:
  1. Merge #9347 and #9386 (they should be pretty trivial to merge)
  2. Rebase and merge the big Bart refactor (#9343)
  3. Discuss/merge the "new" model parallel design: #9316 and #9323
  4. Rebase and discuss/merge this PR

@LysandreJik
Member

Is this PR ready for review? There are a lot of comments that were probably left in for debugging purposes. Let me know if you want a review or if we should come back to it after #9347 and #9386 have been merged.

@stas00
Contributor Author

stas00 commented Jan 4, 2021

It's very ready functionality/concept-wise. It's not 100% ready in terms of commentary, debug traces, etc., but that's not very important until the rest is sorted out, since there are multiple overlapping PRs happening.

Because of the holidays there is a lot of new code which is all inter-dependent and unreviewed, and then there is the huge change of #9343 to merge.

So I think it's best to review it as it is and sort things out; once everybody is happy with the logic and #9343 is in, I will most likely have to do a new PR anyway.

But I need your feedback that what I did is correct.

Think of it as an ongoing code design and generalization PR.

Thanks.

@stas00
Contributor Author

stas00 commented Jan 4, 2021

@patrickvonplaten, your plan works for me. Please ping me when #9343 is in and you're happy with the outcome so that it'll be safe to add MP w/o creating conflicts. Thank you.

But as I commented above, this blocking event doesn't have to interfere with this PR's review - we are just not going to merge it. The review should proceed as normal, and I will take all the agreed-upon changes to the new PR once the dust around the Bart split settles down.

@alexorona
Contributor

alexorona commented Jan 4, 2021

@stas00 @LysandreJik @patrickvonplaten This PR introduces a device_map that is not backwards compatible with 4.1.0. We have to do that at some point (as @stas00 discovered), but let's not have three different versions. We really need to make sure that we have consensus on the final form of the device_map that will work for all models going forward, or we will have to change it again when model parallelization is generalized and some of its functionality is placed in PreTrainedModel. Have you tested this on gpt2, @stas00? And is the code generalizable to models that don't have decoder architectures and can store their attention blocks in attributes like self.h?

Has everyone read this comment? Are we all on board for the plan to generalize model parallelism? Don't have to implement it now, but we need to make sure we've thought through any changes that affect user experience and backward compatibility.

Sorry, I'm in the middle of moving so not keeping close track of all the traffic and could easily have missed something. Also, this content is spread across several PRs, so sometimes I'm getting confused.

@stas00
Contributor Author

stas00 commented Jan 5, 2021

@alexorona, I'm basically leaving all the old code in place, so gpt2 and t5 work as is; this PR only impacts Bart. And in any case it doesn't look like this PR will be merged, since Bart went through a split in #9343, which isn't finalized yet, and I will need to re-do it anyway. But it's no problem, since I know what to do. And see the end of this comment - the whole MP implementation might need to be redesigned altogether.

Since there are so many moving parts, it's very difficult to manage things, and it definitely makes things difficult for reviewers.

So my intention was to merge each of the new things separately, while keeping the old code working, and then to start integrating things bit by bit. The holidays made things pile up, but now that the HF team is back I trust we will form a plan in the next few days.

Important notes:

  1. MP is new here and should be clearly marked as an experimental feature, which means device maps are not fixed and can change at any moment. See "[model parallel] add experimental warning" #9412.

    What we could commit to is having the default device map work - i.e. users don't supply any device map and it just works.

    That's why I propose we start with each model implementing its own device map format (while sharing bits with common code where possible) and then over time we will find a common format.

    If the HF team wants to allocate time, then we need to sit down, look at all the models and decide on the format ahead of time. If I'm not mistaken, it looks like at the moment it's just @alexorona and me who mostly understand what's going on, so it'd be great to have someone from HF get on top of MP. I'd be happy to sit down with that person and explain in person what I learned in the last few weeks. It's not complicated.

  2. As was just pointed out in "rfc: automating the switching of inputs to the device of the params" pytorch/pytorch#49961 (comment), this implementation is highly inefficient since it doesn't take advantage of the idle gpus, so we might have to scrap a big part of it and re-implement it using PP or something similar. The current implementation just uses the extra gpus to expand the available memory, but doesn't take advantage of the extra hardware's compute.

Meanwhile, the deepspeed integration is almost ready and sharded_ddp should be available in the next few days, so users will already have excellent ways to fit huge transformers models on limited hardware. So let's not rush with MP here and take the time to think.

@LysandreJik
Member

From what I understand, model parallelism as it's currently implemented is a naive implementation of what it's supposed to do: offer more memory so that bigger models may be trained using the memory of several devices instead of a single device. It is indeed inefficient, as some devices are idle while others compute, so there's definitely a way of making it more efficient.

@stas00, @alexorona, if you could walk us through what you have learned so that @patrickvonplaten, @sgugger and myself can understand the different options available, that would be great.

Some thoughts to structure our approach towards MP:

  • You mention pipeline parallelism (PP) as a way to be more efficient than model parallelism (MP), as the idle devices can be used while others compute. This intuitively seems like an algorithm to set up during training; do you think we would have to modify the models themselves, like what is currently done with model parallelism?
  • As noted by @sgugger and approved by @patrickvonplaten and myself, working on the MP API of the current models (GPT-2 and T5) is a better test-bed than trying to make it work for all models all at once. Most models are similar, and finding a generic approach (if possible!) should be feasible with just these two models for now.
  • You're right that we should not rush it, and take our time to understand what we can do best for both inference and training.

@alexorona
Contributor

alexorona commented Jan 5, 2021

@LysandreJik No, it's not a naïve implementation of model parallelism. In addition to data parallelism and model parallelism, there is pipeline parallelism, which is the next level of complexity along with zero redundancy. Model parallelism allows us to train bigger models on GPU. Pipeline parallelism + model parallelism would allow us to train these large models faster because the GPUs are not idle. I really think the next step is to make sure model parallelism is generalized and rely on a library -- probably deepspeed -- to implement pipeline parallelism and zero redundancy. deepspeed has something called 3D parallelism, which I believe is a form of pipeline parallelism. @stas00 is that correct?

From my understanding, deepspeed has three major enhancements:

  • 3D parallelism
  • zero-redundancy that reduces the GPU memory footprint of a given module
  • some support for clusters, but I'm hazy on the details

Practical feature implications: We can currently train t5-11b -- I believe the largest model in the library -- in a practical and affordable amount of time on the newest cloud instances. There are three benefits to pursuing pipeline parallelism and zero redundancy libraries:

  • Users could train large models faster
  • Users could train large models on more modest hardware
  • We would be prepared for the eventual release of even larger models in the 20 billion and potentially up to 100 billion parameter range

@stas00
Contributor Author

stas00 commented Jan 6, 2021

Some notes following up to the raised issues:

  • I need to study and experiment before I'm able to answer a lot of the questions you have been asking. For example one of the important questions @alexorona asks is whether the idling GPUs can be utilized to a high capacity by integrating other libraries like deepspeed. I will be able to answer that once I try that.

  • The "naive" part @LysandreJik referred to is that if, say, you spread the model over 8 gpus, 7 gpus will be idling most of the time, so it'd be a terribly expensive training since you'd be paying per gpu and not per its utilization (a toy sketch of this idle-GPU problem follows after this list). So while the current solution works, there must be a more efficient way to do it. One much more complex solution suggested here: "rfc: automating the switching of inputs to the device of the params" pytorch/pytorch#49961 (comment) is the RPC mechanism. Again, I don't have any experience with it, so I will eventually get to try it and comment back.

  • DeepSpeed's solution to huge model sizes is ZeRO - while it says it can support models implementing MP, it also says MP isn't needed since ZeRO alone is a working solution (100B-param models w/o needing MP). My experiments showed that with sharded DDP on my weird hardware setup I can fit 3x more data, and with DeepSpeed 3-5x, and that's just with some default config.

  • We are on the same page wrt making things work on a few models first - t5, gpt2, and bart is ready too. Note that Bart is a better candidate than t5 because it can be asymmetrical wrt encoder/decoder size, so it's slightly more complex (but not by much). We were discussing a specific issue of device_map design, which requires us to look at all models. But that's where it can stop.
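
As a toy illustration of the idle-GPU problem (a made-up module, not the Bart code): with a naive vertical split, only the device holding the currently-executing layers does any work while the other one waits.

import torch
import torch.nn as nn

class NaiveTwoGpuMLP(nn.Module):
    # Half of the layers live on cuda:0, the other half on cuda:1.
    def __init__(self, hidden=1024, layers_per_gpu=3):
        super().__init__()
        self.part0 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_gpu)]).to("cuda:0")
        self.part1 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_gpu)]).to("cuda:1")

    def forward(self, x):
        x = self.part0(x.to("cuda:0"))  # while this runs, cuda:1 sits idle
        x = self.part1(x.to("cuda:1"))  # while this runs, cuda:0 sits idle
        return x

Pipeline parallelism avoids this by splitting each batch into micro-batches so that both devices can be busy at the same time.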

My plan is to finish the DeepSpeed integration - it's almost there - and then look into pipelines next.

Of course, nobody needs to wait for me, I'd be just as happy for others to experiment and teach me instead ;)

I commented on the current design so that the HF team better understands what we have here:
#8771 (comment)
Let's keep the design discussion focused in one thread, otherwise we are spread all over multiple threads... It doesn't matter which - just pick one. If you have questions or need clarifications, please don't hesitate to ask.

@stas00
Contributor Author

stas00 commented Jan 6, 2021

I rebased on #9343 so now it's no longer possible to develop anything on Bart - the check fails because it wants all copy-cats to be the same:

python utils/check_copies.py
Traceback (most recent call last):
  File "utils/check_copies.py", line 305, in <module>
    check_copies(args.fix_and_overwrite)
  File "utils/check_copies.py", line 166, in check_copies
    raise Exception(
Exception: Found the following copy inconsistencies:
- src/transformers/models/pegasus/modeling_pegasus.py: copy does not match models.bart.modeling_bart.BartAttention at line 141
- src/transformers/models/marian/modeling_marian.py: copy does not match models.bart.modeling_bart.BartAttention at line 140
- src/transformers/models/marian/modeling_marian.py: copy does not match models.bart.modeling_bart.BartEncoderLayer at line 275
- src/transformers/models/marian/modeling_marian.py: copy does not match models.bart.modeling_bart.BartDecoderLayer at line 331
- src/transformers/models/blenderbot_small/modeling_blenderbot_small.py: copy does not match models.bart.modeling_bart.BartAttention at line 124
- src/transformers/models/blenderbot_small/modeling_blenderbot_small.py: copy does not match models.bart.modeling_bart.BartEncoderLayer at line 259
- src/transformers/models/blenderbot_small/modeling_blenderbot_small.py: copy does not match models.bart.modeling_bart.BartDecoderLayer at line 315
- src/transformers/models/mbart/modeling_mbart.py: copy does not match models.bart.modeling_bart.BartAttention at line 133
- src/transformers/models/blenderbot/modeling_blenderbot.py: copy does not match models.bart.modeling_bart.BartAttention at line 126
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them.
make: *** [Makefile:25: extra_quality_checks] Error 1

How do I move forward with my work then? I suppose the only way to proceed is to drop Bart and use one of the derivatives? So Bart isn't going MP...

@patrickvonplaten, @sgugger

@sgugger
Collaborator

sgugger commented Jan 6, 2021

That's also why we should pause the BART MP PR and make sure the general API is solid enough. Any change in BART will impact all related models (that was true before the split too, since the other models were subclasses), so the same PR will need to handle BART/Pegasus/mBART/Marian etc., and probably the seq2seq template. So better to make sure we're happy with the design on a model independent from the others, like GPT-2 or T5, first :-)

@patrickvonplaten
Contributor

I rebased on #9343 so now it's no longer possible to develop anything on Bart - the check fails because it wants all copy-cats to be the same: [...] How do I move forward with my work then? I suppose the only way to proceed is to drop Bart and use one of the derivatives? So Bart isn't going MP...

I agree with @sgugger that it would be better to just work on T5 and GPT2 until we have a solid API for now... But in general the idea is to implement the feature in Bart and then run make fix-copies, and all the other models are updated automatically. In case you add a lot of code to Bart (outside of BartAttention), it can very well be that this code has to be manually copied into the other models as well (happy to help then :-) )

@patrickvonplaten
Contributor

And big sorry for making this PR so much harder for you now! But that Bart split had to happen sooner or later

@stas00
Contributor Author

stas00 commented Jan 6, 2021

And big sorry for making this PR so much harder for you now! But that Bart split had to happen sooner or later

Surprisingly, the rebasing was super-simple. So it wasn't a hurdle at all.

@stas00
Contributor Author

stas00 commented Jan 6, 2021

  1. Bart and t5 aren't exactly the same, so a variety of models is needed in order to generalize.
  2. And this PR is much further ahead than the t5 one, though I can spend more time merging it back into t5.

If I switch to one of the original subclasses, say, MBart, and work with it instead - will the copy-checker complain just the same?

@sgugger
Collaborator

sgugger commented Jan 6, 2021

If I switch to one of the original subclasses, say, MBart, and work with it instead - will the copy-checker complain just the same?

I'm afraid so, unless you remove all # Copied from comments, but that defeats the purpose.

@stas00
Contributor Author

stas00 commented Jan 6, 2021

Understood, thank you!

It sounds like this change will make future development of the bart family somewhat painful, since the developer will have to constantly sync multiple files with their new development. And it won't help the reviewers either, since there will now be multiple duplicated diffs.

It'd be much more useful to run the check/sync periodically or at will, rather than enforcing it on each make style, IMO. I guess time will tell.

@stas00
Contributor Author

stas00 commented Jan 6, 2021

Thinking more about the situation: the thing is, this PR works - I put a ton of work into it - and users could start using MP with the Bart family yesterday, e.g. with the --model_parallel flag in the trainer. We don't have to expose the unstable device map, and using the internal default device map is sufficient for most simple uses. If we change to a different, more efficient implementation down the road, it'd be totally transparent to the users. And if it's not the trainer, they can just use model.parallelize() without the device map, or use the device map knowing it may change down the road.

I'd just need to enable self.is_parallelizable that was just added and clean up a bit.

But it's your call.

@sgugger
Collaborator

sgugger commented Jan 6, 2021

e.g. with --model_parallel flag in trainer

That's one of the things to clean up: this flag is not necessary with the current API, since we can detect if a model is parallelized, and we'd avoid confusion with the name. I'm not saying we should throw this PR in the trash, just that it should be paused until we have had time to do all the cleanup we want.

@stas00
Contributor Author

stas00 commented Jan 7, 2021

e.g. with --model_parallel flag in trainer

That's one of the things to clean up: this flag is not necessary with the current API, since we can detect if a model is parallelized, and we'd avoid confusion with the name.

Do tell more? Are you planning to launch MP just because a model supports it? It sounds like you are considering dropping the --model_parallel CLI arg in the trainer.

Or are we talking about different things?

I'm not saying we should throw this PR in the trash, just that it should be paused until we have had time to do all the cleanup we want.

tldr;

  1. I'm fine with hitting the pause button as you suggested.
  2. this is a fully functional implementation - so you actually can send users to this PR branch if they want to use MP with Bart (as the family has been cast out after I rebased on master, it will require quite some work to re-add it to other Bart-like models).

the full story:

The issue is simple: things are complicated. This PR overlaps with #9323 - both have multiple changes and improvements, and I have already documented and commented on each of the changes in both PRs - well, actually 3 PRs (#9316 too) - so leaving such code spread out over several PRs is a recipe for a huge mess down the road. It all came to be because I was working over the holidays and wasn't getting feedback (no complaints, I'm just explaining how it came to be). As a result I was working on new changes, but with Bart, so that I could see how to generalize better. Not knowing what you'd decide, I tried to leave the existing code without any API changes, hence the separate independent PRs.

The bottom line is this: regardless of whether the current implementation is efficient or not, it works. And any future, more efficient implementation will use the same user-side API (or perhaps something more complicated) - at the moment it's just one command to turn the feature on.

So you actually can send users to this PR branch if they want to use MP with Bart-only.

So the other approach I can take is to merge parts of this PR into the t5 MP PR #9323, but that will again be a lot of work, and nobody has even looked at any of those PRs...

But then we are talking about perhaps finding a more efficient solution, and perhaps deepspeed will render a lot of it pointless anyway... (Alex thinks not.) So why waste reviewers' time... makes sense not to.

So yes, let's freeze this up and I go back to work on deepspeed.

I have convinced myself it's the right thing to do and you got to hear my inner talk.

Just remember it's totally functional in case someone needs it.

Thank you for reading.

@stas00
Contributor Author

stas00 commented Jan 7, 2021

As t5 MP is broken in the trainer, I needed to see if it was the same with my Bart MP port - but it works:

rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 ./finetune_trainer.py --model_name_or_path  sshleifer/distilbart-xsum-6-6  --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler  --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 2 --n_val 2 --n_test 2 --do_predict --task summarization --data_dir xsum --model_parallel 

So with this PR you can use --model_parallel automatically with our trainer scripts and Bart models.

@stas00
Contributor Author

stas00 commented Jan 11, 2021

As I was trying to see if I could find a way to utilize the idling GPUs, I ran these benchmarks. I haven't found anything useful yet, but the interesting finding is that while we get a huge performance hit in evaluation with beam size > 1, the training time is actually faster than the non-MP version, despite all the data copying.

This PR beats master on training time by almost half (8.6 sec vs 15.8 sec), but of course it has 2 gpus vs 1 gpu!!! It even beats the DDP solution (10.6 sec) by about 20%!

So perhaps there is something good in here; we just need to understand why it is faster than DDP.

Unfortunately I have an uneven GPU setup, so it's hard to get very useful benchmarks. Perhaps someone with 2 identical GPUs could re-run these and report back.

For posterity here are the results I'm getting with 1x 8gb and 1x 24gb gpus:

# w/o MP w/o DDP


rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 ./finetune_trainer.py --model_name_or_path  sshleifer/distilbart-xsum-6-6  --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler  --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 200 --n_val 200  --task summarization --data_dir xsum

2021-01-10 16:57:43 | INFO | __main__ |   train_runtime = 15.8407
2021-01-10 16:58:02 | INFO | __main__ |   val_runtime = 19.0772

# w/o MP w/ DDP


rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path  sshleifer/distilbart-xsum-6-6  --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler  --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 200 --n_val 200  --task summarization --data_dir xsum

2021-01-10 16:58:42 | INFO | __main__ |   train_runtime = 10.6299
2021-01-10 16:58:53 | INFO | __main__ |   val_runtime = 11.4454

# w/ MP  w/o DDP

rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 ./finetune_trainer.py --model_name_or_path  sshleifer/distilbart-xsum-6-6  --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler  --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 200 --n_val 200 --model_parallel --task summarization --data_dir xsum

2021-01-10 16:49:00 | INFO | __main__ |   train_runtime = 8.6264
2021-01-10 16:51:14 | INFO | __main__ |   val_runtime = 134.0955

val_runtime is very slow due to beam search (beam size 4).

same w/ --eval_beams 1

2021-01-10 16:56:10 | INFO | __main__ |   train_runtime = 8.657
2021-01-10 16:56:41 | INFO | __main__ |   val_runtime = 31.4318


# w/ MP  w/ DDP

rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path  sshleifer/distilbart-xsum-6-6  --output_dir output_dir --adam_eps 1e-06 --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler  --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 200 --n_val 200 --model_parallel --task summarization --data_dir xsum

this doesn't work: can't mix this implementation of MP w/ DDP

AssertionError: DistributedDataParallel device_ids and output_device arguments only work with single-device GPU modules, but got device_ids [0], output_device 0, and module parameters {device(type='cuda', index=0), device(type='cuda', index=1)}.

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been automatically marked as stale and been closed because it has not had recent activity. Thank you for your contributions.

If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this Mar 6, 2021
@stas00 stas00 added the WIP label Mar 6, 2021
@stas00 stas00 reopened this Mar 6, 2021
@stas00 stas00 removed the wontfix label Mar 9, 2021
@stas00
Contributor Author

stas00 commented Jun 4, 2021

too long. closing.

@stas00 stas00 closed this Jun 4, 2021
@fabrahman

Hello, @stas00 is there any update on BART based model parallelism? also about model.parallelize() for BlenderBot? Thanks.

@stas00
Contributor Author

stas00 commented Jun 5, 2022

This line of work has been abandoned as it's highly inefficient. Please use DeepSpeed, which works with any model: https://huggingface.co/docs/transformers/main/main_classes/deepspeed

Labels: Model Parallel, WIP
Successfully merging this pull request may close these issues: model parallelism for BART
6 participants