[wip] [pipeline parallel] t5 - experiment #9765
Conversation
Thanks @stas00, I am getting what looks like a torch error when I run this (I'm not sure if the "Failed to look up the IP address for the hostname" error is related -- I'm not able to find much on this except for an issue from a few days ago that mentions this: pytorch/pytorch#50700):
Output:
It looks more like a warning, as it recovers with a fallback. Make sure you have:
It looks like I forgot to commit the last change. My apologies. Could you please update and try again?
Thanks -- appears to be working -- on t5-3B it spreads it evenly across the 4 A100s (13.0-13.3GB each with a batch size of 1). For t5-11B there's an out of memory error -- I suppose (naively) if 11B is ~3.7x larger than 3B then it would require ~49GB per card without some form of offloading?
Thank you for confirming that you were able to use it with t5-3b on your 4 gpus. Were you able to get a decent gpu utilization across the board? Or were they all under 25%? Please make sure you read my notes on the balancing in the OP and experiment with the device map so that all gpus get a balanced GPU memory usage. gpu0 is already busy with many things, so I'd try a spread of 2/4/4/4 parts or perhaps 1/3/3/3 in your definition of:
In this example we have a 1/3 parts balance between gpu 0 and 1, i.e. 3 times more layers for gpu 1. Of course, it needs to be adjusted to 4 gpus, and I don't remember how many encoder blocks t5-11b has, but as I mentioned, if you look at the logs you will find a ready map there; just re-adjust it to balance things better. Please let me know if I communicated clearly what I'm trying to say - we want all 4 gpus to have about the same memory usage - then we maximize the chance to fit t5-11b on those 4 gpus.
Next we need to try to bolt DeepSpeed onto it. So we will try to use 2 gpus for the pipeline and 2 gpus for ZeRO-DP, and perhaps some ZeRO-Offload too. I should get access to 4 gpus soon and will start working on figuring that step out. I will post back once I have something practical to share.
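For illustration only, a roughly 1/3/3/3 map for 4 gpus in the dict form used later in the thread might look like the sketch below; the 24-block encoder count is an assumption to verify against the default map printed in the logs:

```python
# hypothetical 4-gpu device map for a ~1/3/3/3 balance over an assumed 24-block stack;
# gpu 0 gets fewer blocks since it also carries embeddings and other overhead
device_map = {
    0: list(range(0, 3)),    # 3 blocks
    1: list(range(3, 10)),   # 7 blocks
    2: list(range(10, 17)),  # 7 blocks
    3: list(range(17, 24)),  # 7 blocks
}
```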
Thanks -- the autobalancing (just "chunks=4") actually seemed to give nearly entirely even results on -3B (the ~13.0-13.3GB each), so I tried that with 11B instead of manually supplying the device map (since it seemed a bit uneven when I tested on -base) -- but I'll tinker on 11B and report back.
FYI, currently the automatic device map just tries to split
What's interesting is that I'm not generally observing GPU0 to have a higher load. Here's an example with unifiedqa-t5-3b (essentially just a further pre-trained t5-3b, not relevant here), chunks=4, autobalancing (with a different visualization tool). They all tend to show about the same RAM usage over time. The graph also shows the utilization (also generally under 30% most of the time):
BTW -- I tinkered with different manual device_map settings for t5-11b, but it always quickly gave out of memory errors.
Oh, what tool is that? I want it too! It looks like different GPUs behave differently; it will take some experimentation to make sense of it all. But clearly you're also not seeing much benefit from the pipeline over the native MP. Same as I. Either my workaround to make it work slows everything down or there is another problem elsewhere. As I mentioned, I'd like to redesign my implementation in the hope of removing that slowdown.
Thank you for the experimentation. I'm still waiting to get access to a 4-gpu setup and when it happens will immediately start experimenting with bolting DeepSpeed on it and then will get back to you.
Thanks -- this handy visualization tool is nvtop -- I just found it to plot the relative changes rather than stare at nvidia-smi and hope to keep it all in my brain. It's available with apt (sudo apt-get install nvtop). Happy to offer my rig for some testing if you need a 4 GPU setup sooner. :)
Oh, yes, I had it and forgot about its usefulness. Thank you! I typically use
but this is way better.
If I don't find access by tomorrow I will gladly accept your generous offer, @PeterAJansen. Thank you!
hmm, how do you get a split screen per card in nvtop? For some reason my version reports both cards as one card. I don't see any command line options to configure that.
hmmm, it actually worked out-of-the-box for me (but looks very different depending on the dimensions of the terminal). Does it show only 1 GPU (with memory for both?), or two separate GPUs?
Sorry, I forgot to mention I tried this already to no avail. I gave it a huge console. I even tried various terminals - same. I think it may have to do with my 2nd card being an rtx-3090 - it doesn't work with cuda < 11.1, and most likely nvtop was built against cuda-10, so while it replicates the nvidia-smi stats, it can't access nvml for that card and thus doesn't show the graph. Yup, I installed nvtop on a machine with 2 normal gpus and it shows them both in the same-size terminal. So it just can't handle rtx-30* unless it's rebuilt from source against cuda-11.1+.
But even when it works it gives no way to separate the 2 gpus other than colors, and 4 lines of often similar magnitude for different things are impossible to make sense of. This is an odd design.
:-/ That's unfortunate (though I suppose the cost of using bleeding-edge hardware). The A100s are supported with CUDA 11.0, so they must just squeak in on the current version available with apt. (And, the usability is a little unusual, but with ASCII graphics there are strong limits... :) )
pytorch w/ cuda-11.2 nightly should be available any day now. cuda-11.2 has been out for a month now.
This is a good point. But at least one could use a double line or asterisks or something to differentiate 4 different things. Perhaps some people can track 4 similar colors and remember which is which. Not me. I guess the source code is there, so if I really need to I could probably hack it to be more user-friendly.
Update: this overload of the term MP to mean totally different things is a big problem. I was sure I could easily combine the non-DeepSpeed pipeline with DeepSpeed after reading
So this particular branch takes us nowhere closer to an integration of PP with DeepSpeed. Back to the drawing board.
# Context from the diff: the existing per-layer loop body in T5Stack.forward
# shown around the new code.
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]

# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights),
# (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
    encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
    present_key_value_states = present_key_value_states + (present_key_value_state,)

if output_attentions:
    all_attentions = all_attentions + (layer_outputs[3],)
    if self.is_decoder:
        all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
    for k, v in self.device_map.items():
        if i == v[-1] and "cuda:" + str(k) != self.last_device:
            hidden_states = hidden_states.to("cuda:" + str(k + 1))

# New code in this PR: wrap each block in a pipeline segment.
# rewrite the model after pre-trained weights were loaded
layers = [
    T5StackPipeSegment(
        idx,
        n_layers,
        layer_module,
        self.is_decoder,
        head_mask[idx],
        encoder_head_mask[idx],
        output_hidden_states,
        use_cache,
        output_attentions,
        all_hidden_states_add,
        present_key_value_states_add,
        all_attentions_add,
        all_cross_attentions_add,
    )
    for idx, layer_module in enumerate(self.block)
]
# block_sequential = nn.Sequential(*layers)

# for now don't enable the pipe
if self.pipeline_is_enabled:

    # print("using partitioning: ", dict(zip(devices, layer_splits)))
    for device_id, layer_partition in self.device_map.items():
        for layer_id in layer_partition:
            # print(f"{layer_id} => {device_id}")
            layers[layer_id].to(device_id)

    block_sequential = nn.Sequential(*layers)
    block_pipe = Pipe(block_sequential, chunks=self.pipeline_chunks, checkpoint="never")
IIUC, we are creating the entire pipeline model each time in the forward pass. Isn't this going to be pretty expensive? Why don't we just create the Pipe only during init and just use it in the forward pass?
Because, if you look a bit higher where I create T5StackPipeSegment to populate layers, I don't have all those arguments available at init time - i.e. I only get them during forward. But I can pass them to forward, of course. I was just trying to come up with ways to overcome the Tensor/tuple(Tensor) restriction, and at the time this seemed like a good workaround. If I don't do it, the passing of arguments to forward becomes even more complicated.
But as you started looking into why there is no speed up over naive MP, this could be the cause - I have no idea how long it takes to build these 2 pipes, but surely it's expensive to do it on every model run.
My first experiment was just to make the pipe work at all costs, so clearly this is not a way forward.
And btw, continuing our discussion from the pytorch issue, we also can't easily change the shape of the model, because we need to adhere to the pre-trained model layout to be able to load pre-trained models. I suppose it could be re-mapped after the normal dist state is loaded. So that's another complication I was trying to overcome here - injecting proxy layers that had no weights of their own.
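To make the direction discussed above concrete, here is a minimal sketch (not this branch's code; the class and attribute names are hypothetical) of building the Pipe once at construction time and keeping only tensors in forward; it assumes torch.distributed.rpc has already been initialized:

```python
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe

class PipelinedStack(nn.Module):
    # Hypothetical wrapper: the segment list and device placement are fixed up front,
    # so the Pipe can be created once instead of on every forward pass.
    def __init__(self, segments, device_map, chunks):
        super().__init__()
        for device_id, layer_ids in device_map.items():
            for layer_id in layer_ids:
                segments[layer_id].to(f"cuda:{device_id}")
        self.pipe = Pipe(nn.Sequential(*segments), chunks=chunks, checkpoint="never")

    def forward(self, hidden_states):
        # only tensors (or tuples of tensors) may cross segment boundaries, so any
        # non-tensor state has to be bound into the segments or handled elsewhere
        return self.pipe(hidden_states).local_value()
```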
too long. closing.
We will test this branch soon.
There are probably some things that can be salvaged from this PR, but the main utility of it is to see the difficulties I ran into. And of course, this is not a good solution, not only because the code is absolutely nuts, but because it's very inefficient. As I mentioned in the other thread, pytorch now has a better API, so some of the encoding/decoding of non-tensor inputs/outputs I did won't be needed anymore, as it now supports non-tensor inputs/outputs.
This PR is not ready for reviews.
I'm putting it up primarily for those who want an early preview of a possible Pipeline solution. @PeterAJansen, you wanted to see if you could get it working with your 4x 40GB rig and t5-11b. Please give it a try.
Intention
We want to replace the naive model parallel (MP) implementation with a more efficient pipeline parallel (PP) implementation, which takes advantage of all participating gpus, rather than having one gpu run while the rest idle, which is the case with the naive MP.
To give you a visual from the GPipe paper,
You will find a new argument chunks, which is how many pipeline stages you want to add; in the 2nd diagram of the image above you can see that chunks=4.

So with chunks=1 you get the naive MP, but it'd be even slower than the naive MP because of the RPC overhead.
Overview

Porting t5 to Pipeline Parallelism proved to be a study in hacking, due to the very restrictive original pipeline interface, which only allows tensors or tuples of tensors as input/output arguments in forward, while in transformers we have a ton of very complex variables to pass to forward and return from it.

We are trying to change the Pipeline design to be much more user-friendly: pytorch/pytorch#50693
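For context, here is a bare-bones sketch of the upstream Pipe API this branch builds on (module path as of the pytorch-nightly of that time); it shows the tensors-only restriction and the chunks argument, with made-up layer sizes:

```python
import os
import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe rides on top of the RPC framework, so RPC must be initialized first
# (a single local worker is enough for a single-process run)
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

# two stages on two devices; each stage may only receive/return tensors
# or tuples of tensors, which is the restriction discussed above
stage0 = nn.Linear(1024, 1024).to("cuda:0")
stage1 = nn.Linear(1024, 1024).to("cuda:1")
model = Pipe(nn.Sequential(stage0, stage1), chunks=4, checkpoint="never")

x = torch.randn(8, 1024, device="cuda:0")
out = model(x).local_value()  # forward returns an RRef; fetch the local tensor
```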
This implementation tries to take advantage of the 2 natural stacks (encoder and decoder), so I implemented it as 2 pipes, each with one segment per block (e.g. 6 for t5-small).

Please don't even bother looking at the code, it is one big hack which took many hours to come up with to make the pipeline work, so clearly it is not something very portable or readable.
Setup
important: you need pytorch-nightly to be able to use this.
Just create another conda env so as not to mess up your normal env; pt-nightly is a solid piece of software, I use it all the time. Here is a quick copy-n-paste of what you will need - just edit the location of the transformers checkout dir.
Down the road I will also look at using fairscale/deepspeed, but for now pytorch is just more accessible and hopefully will become more flexible soon.
Deployment: script
You can deploy PP directly via your own trainer/script, e.g. this is what I have been using while developing it:
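The script itself isn't reproduced here; purely as an illustration, an invocation might look something like the sketch below. The pipeline_enable name and its arguments are hypothetical placeholders for whatever this branch actually exposes (pipeline_finalize is the reporting hook mentioned under Future):

```python
# hypothetical sketch only: pipeline_enable and its arguments are placeholders,
# not the actual API of this branch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# placeholder for however this branch turns on pipeline parallelism
model.pipeline_enable(chunks=4, device_map={0: [0, 1, 2], 1: list(range(3, 12))})

batch = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").to("cuda:0")
outputs = model.generate(**batch)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

model.pipeline_finalize()  # reporting hook mentioned under Future below
```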
Deployment: HF Trainer
But you can also use HF trainer. I tweaked the trainer to activate PP with:
This will let the program do the partitioning for you. But you can control the partitioning manually by passing:
Here we basically pass the equivalent of a dict {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}, which, btw, you can pass in your script as well.

The syntax is what you'd pass to range, so `device_map=0:0-3,1:3-12` is the same as the dict above. The keys are the gpu ids.
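Based on that description (range semantics, gpu ids as keys), a small parser for this syntax could look like the following sketch; it is an illustration, not the PR's actual parsing code:

```python
def parse_device_map(spec: str) -> dict:
    # "0:0-3,1:3-12" -> {0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9, 10, 11]}
    # each entry is gpu_id:start-stop, with start/stop interpreted like range()
    device_map = {}
    for entry in spec.split(","):
        gpu_id, span = entry.split(":")
        start, stop = (int(x) for x in span.split("-"))
        device_map[int(gpu_id)] = list(range(start, stop))
    return device_map

assert parse_device_map("0:0-3,1:3-12") == {0: [0, 1, 2], 1: list(range(3, 12))}
```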
The number of layers is at the moment just the depth of the encoder stack, so 12 for t5-base, 6 for t5-small, etc.
Later we should have a different way as well, where we define the desired balance, rather than the specific layers.
Since each t5 model has a different number of blocks, the easiest way is to first run without the device map and then check the logger output, which will show you which device map it's using. Then I recommend re-balancing it so that gpu0 has fewer layers than the remaining gpus.
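To make the re-balancing advice concrete, here is one possible helper (an illustration, not part of this branch) that splits the blocks so gpu0 gets a smaller share:

```python
def make_device_map(n_blocks: int, n_gpus: int, gpu0_parts: int = 1, other_parts: int = 3) -> dict:
    # give gpu0 fewer blocks than the other gpus (a 1/3 split per gpu by default,
    # matching the balance discussed earlier in the thread)
    weights = [gpu0_parts] + [other_parts] * (n_gpus - 1)
    total = sum(weights)
    sizes = [n_blocks * w // total for w in weights]
    sizes[-1] += n_blocks - sum(sizes)  # absorb the rounding leftover into the last gpu
    device_map, start = {}, 0
    for gpu_id, size in enumerate(sizes):
        device_map[gpu_id] = list(range(start, start + size))
        start += size
    return device_map

# e.g. t5-base (12 blocks) on 2 gpus -> {0: [0, 1, 2], 1: [3, 4, ..., 11]}
print(make_device_map(12, 2))
```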
Benchmarks

example for 2x 24GB cards
Performance-wise: tune chunks so that there is enough in the pipe that the gpus don't idle, but don't make it too big, as then performance goes down. But I wasn't able to get above 50% gpu utilization, so it's not much different from the naive implementation - I don't know yet why - probably data copying takes most of the overhead.

Here are some stats on 2x 24GB Titan RTX:
Baseline: (1gpu)
XXX: need to re-test with rebased code-base
Now with pipeline:
XXX: need to re-test with rebased code-base
Future
I'm also trying to instrument this feature with reporting that will help users fine-tune chunks/device_map. This is the model.pipeline_finalize() call. Things I'm thinking would be useful:

Any other ideas/requests/needs?
@PeterAJansen, please let me know if you managed to run this on your 4x gpu setup.
Next, I think I'm going to scratch the current implementation and try a new one afresh.
Also this PR should be good enough to try to figure out how to use with DeepSpeed, once I get access to 4 gpus (need at least 4 gpus to do 2D parallelism).
I did warn you not to look at the code.
I also removed big chunks of the MP code for now, as it was getting in the way and adding noise; I will restore it once I've sorted this all out.