stuck in trying to combine PP with DeepSpeed #710
Hi @stas00, I haven't worked with torch's pipeline parallelism so I'm not sure of the specific issues you are running into. You shouldn't need to restrict which devices are visible to DeepSpeed. Each GPU should have its own process, and those ranks are grouped into data and model parallel groups. Here's a sneak preview of our Megatron/DeepSpeed 3D training, and how we set up a 3D topology.
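For intuition, here is a minimal sketch (hypothetical, not the actual Megatron/DeepSpeed topology code) of how ranks could be arranged into model- and data-parallel groups when there is one process per GPU:

```python
# Hypothetical illustration: arrange world ranks into a 2D grid so that each
# rank belongs to exactly one model-parallel group and one data-parallel group.
# This is not the actual DeepSpeed/Megatron topology code.

def build_groups(world_size: int, model_parallel_size: int):
    data_parallel_size = world_size // model_parallel_size
    # consecutive ranks share a model-parallel group
    model_groups = [
        list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        for i in range(data_parallel_size)
    ]
    # ranks with the same offset inside their model group share a data-parallel group
    data_groups = [
        list(range(j, world_size, model_parallel_size))
        for j in range(model_parallel_size)
    ]
    return model_groups, data_groups

# 4 GPUs with 2-way model parallelism:
#   model groups: [0, 1], [2, 3]
#   data  groups: [0, 2], [1, 3]
print(build_groups(4, 2))
```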
Thank you, Shaden. This clarification should help me figure out how to debug this further. Thanks also for the sneak preview - I will try to derive from the code what's meant to happen.
I tried to simplify the situation greatly by using the pipeline without spreading the layers out over multiple gpus, so the pipeline protocol gets launched but remains contained in just one gpu, and deepspeed comes in with 2 DP processes. It still fails. So now I have simplified the world to:

setup: 2 gpus
2 model parallel groups: [g0], [g1]

Now it crashes here:
In this situation there should be no interaction between gpu 0 and gpu 1 - they should be totally unrelated, as they are in 2 different DP groups. But you can see this:
The error comes from DeepSpeed. I wonder if there is some clash between the 2 RPC processes (since the pytorch pipeline requires the user to initialize its own RPC). If I disable RPC, I get:
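(As an aside, the per-process RPC bring-up that the torch pipeline requires looks roughly like this - a sketch with placeholder address/port values, not the actual code from my setup:)

```python
# Sketch of the per-process RPC init the torch pipeline expects
# (illustrative; MASTER_ADDR/MASTER_PORT values are placeholders).
import os
import torch.distributed.rpc as rpc

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Each process brings up its own single-worker RPC instance; if two DP
# processes reuse the same address/port settings, a clash is conceivable.
rpc.init_rpc("worker", rank=0, world_size=1)
```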
This error makes me think that the computation graph is crossing GPUs and autograd is complaining:
DeepSpeed moves the model to its device. I think there may also be an issue with the model/data parallel groups. For model parallelism with two GPUs, you could have:

1 model parallel group: [g0, g1]

Or pure data parallelism would be:

2 model parallel groups: [g0], [g1]
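A rough sketch of what those two layouts could look like as process groups (illustrative only, not the mpu code itself):

```python
# Illustrative only: the two 2-GPU layouts described above, expressed as
# torch.distributed process groups. Assumes torch.distributed is already
# initialized with world_size=2, and that every rank calls new_group() with
# the same arguments in the same order.
import torch.distributed as dist

# (a) 2-way model parallelism, no data parallelism:
#     1 model parallel group: [g0, g1]
mp_group = dist.new_group(ranks=[0, 1])

# (b) pure data parallelism:
#     2 model parallel groups: [g0], [g1]  (each rank is its own MP group)
mp_group_for_rank0 = dist.new_group(ranks=[0])
mp_group_for_rank1 = dist.new_group(ranks=[1])
```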
Thank you for this insight, @ShadenSmith - so let me validate that: I modified your original mpu to add 2 additional functions that return these 2 groups. They mirror the functions that DeepSpeed would call, but return the rank ids rather than the group objects, so it looks correct:
i.e. my pipeline is locked to a single device per process - process 0 to gpu 0 and process 1 to gpu 1, and so DP runs on 0+1. This is on purpose, in the hope of solving this simpler setup first. I init the MPU with:
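That simplified setup could be expressed roughly like this (a sketch, assuming the adapted Megatron-style mpu and a launcher that sets LOCAL_RANK):

```python
# Sketch of the simplified 2-gpu setup: each process is pinned to its own gpu
# and the model-parallel size is 1, so every rank is its own MP group and DP
# runs across gpus 0 and 1. Assumes the adapted Megatron-LM mpu is importable
# and that the launcher provides the usual env:// rendezvous variables.
import os
import torch
import torch.distributed as dist
import mpu  # the adapted Megatron-LM mpu (assumption)

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)      # process 0 -> gpu 0, process 1 -> gpu 1

dist.init_process_group(backend="nccl")
mpu.initialize_model_parallel(1)       # MP groups: [g0], [g1]; DP group: [g0, g1]
```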
I launch deepspeed with just:
and this machine has only 2 gpus. It does the same if I pass
I don't know if it helps, but I'm using the original MPU and my logging shows that the groups are correct. I just had to convert it into a class and add 2 debug functions:
The rest of the code is the same as https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM/mpu/initialize.py
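The two debug helpers could be something as simple as the following (hypothetical names and implementation, not the actual code from the thread):

```python
# Hypothetical sketch of the two debug functions: the same information as
# get_model_parallel_group() / get_data_parallel_group(), but returned as
# plain lists of global rank ids so the topology can be logged and checked.

class MPU:
    def __init__(self):
        # filled in by initialize_model_parallel(), same construction as
        # Megatron-LM's mpu/initialize.py, except the rank lists are also kept
        self._model_parallel_group = None
        self._data_parallel_group = None
        self._model_parallel_ranks = None   # e.g. [0, 1]
        self._data_parallel_ranks = None    # e.g. [0, 2]

    def get_model_parallel_group(self):
        return self._model_parallel_group

    def get_data_parallel_group(self):
        return self._data_parallel_group

    # the 2 debug helpers
    def get_model_parallel_group_ranks(self):
        return self._model_parallel_ranks

    def get_data_parallel_group_ranks(self):
        return self._data_parallel_ranks
```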
I made partial progress. When I was fighting to get around the limitations of the pipeline API I left a few hardcoded values behind. I have to start thinking in 2D - clearly I wrote the pipeline with 1D thinking.
OK, I think I can see the problem now. So we have 4 gpus and we want to build a 2D Pipeline with:

2 model parallel groups: [g0, g1], [g2, g3]
2 data parallel groups: [g0, g2], [g1, g3]
Now I start deepspeed with 4 gpus, and it starts 4 processes instead of 2! Obviously my code tries to fire off 4 pipelines instead of 2 and everything breaks. Things work just fine in the 2-gpu world with a pipeline of depth 1 (no pipeline per se).
Since we have just 2 processes, all works out correctly. So, what's the proper way to launch a 4-gpu setup on deepspeed, so that the Pipeline gets hit with only 2 processes and not 4? Do I just ...? I tried ... I hope my explanation is understandable. Thank you.
I think I now understand that this of course can't work, since I was following: ...

OK, so let's start from scratch. I have a pipeline that is not a DeepSpeed pipeline. How do I integrate it with DeepSpeed? Thank you.
I started looking at your source code and I see that an outside PP isn't supported, and that by MP you presume horizontal MP.
So basically I misinterpreted the features.md doc as saying that I could use a non-DeepSpeed PP with DeepSpeed, but it doesn't look like that's possible at the moment.
I have been trying many different ways, but I'm still stuck trying to combine a Pipeline (from pytorch-nightly) with DeepSpeed.
I adapted the MPU from https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM/mpu, which, based on some tests I ran, seems to produce what I need.
But I can't figure out how to launch DeepSpeed to make things work.
Setup: 4 gpus
2 model parallel groups: [g0, g1], [g2, g3]
2 data parallel groups: [g0, g2], [g1, g3]
Inside the pipeline I correctly switch model layers to either [g0, g1] or [g2, g3].
I pass mpu to deepspeed init as suggested in the docs:
I init the mpu with mpu.initialize_model_parallel(2) for 2 gpus per MP group.
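Put together, the intended wiring is roughly the following (a sketch; args, model and the rest of the training setup are placeholders, and mpu is the adapted Megatron-LM module):

```python
# Rough sketch of the intended 4-gpu wiring (placeholders: args, model).
import deepspeed
import mpu  # the adapted Megatron-LM mpu (assumption)

deepspeed.init_distributed()          # or torch.distributed.init_process_group
mpu.initialize_model_parallel(2)      # 2 gpus per MP group ->
                                      #   MP groups: [g0, g1], [g2, g3]
                                      #   DP groups: [g0, g2], [g1, g3]

model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,                        # parsed args, incl. the deepspeed config
    model=model,                      # the pipeline-wrapped model
    model_parameters=model.parameters(),
    mpu=mpu,                          # hands the MP/DP groups to DeepSpeed
)
```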
Now how do I launch deepspeed to use my PP and to achieve 2D Parallelism? Do I tell it about 4 gpus or just gpus 0 and 2?
gpus 2 and 1 shouldn't end up together in the pipeline - they are in the wrong groups.
Somehow I need to tell deepspeed that there are 4 gpus to work with but that I'm using a Pipeline.
So I'm stuck :(
Please help!
Thank you!
p.s. do you by chance have some sort of slack where quick questions of this kind could be asked? I'm sure the answer is very simple, but I couldn't find it anywhere in the docs. Some of us ask questions on the pytorch slack, to which I believe you would be very gladly invited - I'm not an admin, but I would be happy to ask; if you email me your address I will get you invited - I'm at stas@stason.org.