[Tensor Parallelism] Megatron-LM to transformers #10321
@stas00 thanks for starting this thread! I guess, in order for everyone to be on the same page, a brief explanation of horizontal parallelism is needed. This would be a good place for future reference and to introduce other contributors to the core concepts.

NOTE for everyone reading: if you find any of the explanations below confusing, you can read about Megatron-LM in much more detail in its original paper: https://arxiv.org/pdf/1909.08053.pdf

### The core idea

The main thing that separates Megatron-style (horizontal) parallelism from vertical parallelism is the way that it splits the model layers between GPUs without the need for idle time during training/inference (i.e. waiting while the previous GPUs complete their work on the previous layers of the model). This makes the whole process much more asynchronous, just like in MapReduce. Here's my rough sketch of how it looks:

Now the question is: how do we split the computation of those layers so that the parallelized model weights would be equivalent to the original, unparallelized ones?

### Parallelized layers

Let's start with a simple building block of any transformer: a fully connected layer (`nn.Linear`) followed by a nonlinear activation (GeLU). Following the Megatron paper's notation, we can write the dot-product part of it as `Y = GeLU(XA)`. If we look at the computation in matrix form, it's easy to see how the matrix multiplication can be split between multiple GPUs:

Using this principle, we can update an MLP of arbitrary depth without the need for any synchronization between GPUs until the very end, where we need to reconstruct the output vector from shards. The authors provide a helpful illustration for that:

### Quick note on self-attention

Parallelizing the multiheaded attention layers is even simpler, since they are already inherently parallel, due to having multiple independent heads!

### Practical implementation

If you want to just dive right in, here are the basic building blocks implemented in Megatron-LM. All of these rely on basic communication primitives:

```python
def _split(input_):
    world_size = get_tensor_model_parallel_world_size()
    input_list = split_tensor_along_last_dim(input_, world_size)
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()
    return output
```
```python
def _gather(input_):
    world_size = get_tensor_model_parallel_world_size()
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
    return output
```
```python
class ScatterToModelParallelRegion(torch.autograd.Function):
    def forward(ctx, input_):
        return _split(input_)

    def backward(ctx, grad_output):
        return _gather(grad_output)
```
```python
class GatherFromModelParallelRegion(torch.autograd.Function):
    def forward(ctx, input_):
        return _gather(input_)

    def backward(ctx, grad_output):
        return _split(grad_output)
```

In a single transformer layer, there are 4 communication operations in total, for the forward and backward passes.

### Other things to consider

**Parallelized embeddings and output logits**

Since the weights of the input and output embeddings of BERT/GPT2 are tied, they require a coordinated modification. In the original implementation, the input embedding matrix is parallelized along the vocabulary dimension (column-wise), and the output embeddings' matrix multiplication is parallelized together with the cross-entropy loss to reduce the communication size (see the end of section 3 in the paper).

**Model-parallelism-aware dropout**

Transformers have dropout layers outside the model-parallel regions before residual connections, and within model-parallel regions in the self-attention block. Because some dropout layers are in a model-parallel region while others are not, we need to treat random number generation carefully to ensure dropout works correctly. See appendix B.2 in the paper for reference.

**Hybrid model and data parallelism**

Combining horizontal parallelism with data parallelism requires grouping the GPUs in a specific way, as described in appendix B.1.
Phew! That felt like the start of a whole blog post 😄

As for porting all of this, I would follow fairseq's example and copy Megatron-LM's parallel layers verbatim into an existing (but separate) model implementation. After the first semi-working prototype we could figure out how to implement the switching mechanism between a homogeneous model and a parallelized one, but it's too early to think about that, IMO.

What do you think, @stas00 ?
Amazing! Thank you for this awesome presentation, @anton-l! This could totally be a great blog post - I agree!

Let me study the information you shared and I will follow up then. Until then, I have a quick suggestion: do you have easy access to 2 GPUs? That would be enough to make a proof of concept.

I suppose it'd be easier to implement this for Megatron-LM, but the main use would be t5 and gpt2, where we have most huge models at the moment. So we could start there as well, if it works for you. That also can be worked on independently of your Megatron-LM PR.
Regarding the setup: I can borrow a second GPU for the time being, that shouldn't be a problem :)

As for the models, I think GPT2 is a good candidate for our experiments, since the transformers implementation is already stable and has multiple smaller checkpoints for quick demos. Also, I don't think we should be too concerned about porting the 8 original splits of fairseq's megatron, since I've already concatenated them for the model's PR. If everything was done correctly, this potentially allows us to create an arbitrary split across 2^n devices, not just 8.
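If the merged checkpoint's sharded dimensions divide evenly, re-splitting is just slicing along the right axis. A hypothetical numpy sketch for a plain column-parallel weight (attention weights need head-aware splitting, so this covers only the simplest case, and the `reshard` helper is invented for illustration):

```python
import numpy as np

hidden = 16
# hypothetical merged column-parallel weight, e.g. an MLP's first linear layer
W = np.arange(hidden * 4 * hidden, dtype=np.float32).reshape(hidden, 4 * hidden)

def reshard(weight, world_size):
    # re-split a merged column-parallel weight across a new device count;
    # only valid when the sharded dimension divides evenly
    assert weight.shape[1] % world_size == 0
    return np.split(weight, world_size, axis=1)

for n in (1, 2, 4, 8):
    shards = reshard(W, n)
    assert len(shards) == n
    # concatenating the shards recovers the merged weight exactly
    assert np.array_equal(np.concatenate(shards, axis=1), W)
```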
Sounds good on all accounts. GPT2 would be perfect, @anton-l! I had the same thought about just splitting your merged model if needed. Please let us know how we can support you in this endeavor.

Just so you're aware, I mentioned in the other thread the DeepSpeed version of their Megatron-LM port - perhaps theirs is newer - I haven't had a chance to study it yet: https://github.com/jeffra/DSE/tree/master/megatron-lm. You can diff the different versions against the baseline - I assume it has been changed, but perhaps it hasn't. Have a look if you want; if not, that's fine too. It will be good to start anywhere.
@anton-l Thanks for the great work on this, it's really nice to be able to load the pretrained model, so thanks for that too! Have you made any progress on fine-tuning across multiple GPUs? Would love to see if the results get any better with some fine-tuning...
@anton-l, let's do it if you have the resources and interest? Let me know how I can be of help. Now, having used Megatron-LM in the BigScience experiments, it's time to port it to transformers.
@stas00 @anton-l Just curious, is Megatron-LM now ported to transformers? Or the proof of concept mentioned in:
I would love to work on this issue, if there is anything I could do!
Thanks for the nice overview. Having read the paper, I disagree with the following statement (emphasis mine):
If you split the first layer's weights across columns, its outputs are split across columns too, so you then split the second layer's weights across rows, and you need to all-reduce the outputs before applying the next non-linearity. This is explained in Section 3 of the paper.
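The paper's scheme can be checked numerically in a few lines: split the first weight column-wise so GeLU applies shard-locally, split the second weight row-wise, and recombine with a single sum (the all-reduce). A single-process numpy sketch under those assumptions, with the `gelu` helper using the tanh approximation:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8))   # input activations
A = rng.standard_normal((8, 16))  # first linear layer's weight
B = rng.standard_normal((16, 8))  # second linear layer's weight

reference = gelu(X @ A) @ B       # unsharded two-layer MLP

# world_size = 2: A is column-parallel, B is row-parallel
A1, A2 = np.split(A, 2, axis=1)
B1, B2 = np.split(B, 2, axis=0)
partial1 = gelu(X @ A1) @ B1      # computed entirely on "GPU 0"
partial2 = gelu(X @ A2) @ B2      # computed entirely on "GPU 1"
parallel = partial1 + partial2    # the single all-reduce at the end

assert np.allclose(reference, parallel)
```

The sum works because GeLU is elementwise, so it commutes with the column split of `X @ A`; splitting `A` row-wise instead would force a synchronization before the non-linearity, which is exactly the point above.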
🚀 Feature request
Splitting off the discussion that started here: #10301 (comment), to add a potential future feature of transformers: Tensor Parallelism (Horizontal Model Parallelism). For bigger context, please see the Parallelism notes.
Let's start with an important clarification: MP can mean many different things.
At the moment I think it's only Megatron-LM that implements Horizontal MP. @anton-l has ported that model to `transformers`, except the Horizontal MP parts, since currently `transformers` doesn't yet have support for it. There is already naive Vertical MP in t5 and gpt2 thanks to @alexorona's work; I ported Bart too but it's unmerged, and there is an ongoing effort to figure out how to implement the Pipeline. All of these will have to co-operate with each other and also share common tools.

@anton-l started sharing what needs to be done to make that important feature available - and then down the road potentially make it available to other (all?) `transformers` models.

@anton-l, the floor is yours.