stuck in trying to combine PP with DeepSpeed #710

Open
stas00 opened this issue Jan 29, 2021 · 9 comments

@stas00 (Collaborator) commented Jan 29, 2021

I have been trying many different ways, but I'm still stuck trying to combine a Pipeline (from pytorch-nightly) with DeepSpeed.

I adapted an MPU from https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM/mpu, which, based on some tests I ran, seems to produce what I need.

But I can't figure out how to launch DeepSpeed to make things work.

Setup: 4 gpus

2 model parallel groups: [g0, g1], [g2, g3]
2 data parallel groups: [g0, g2], [g1, g3]

Inside the pipeline I correctly switch model layers to either [g0, g1] or [g2, g3].
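For illustration, a minimal sketch (made-up layer sizes, not my actual model code) of how the stages end up placed on one MP group's devices before being wrapped in torch's Pipe:

    # assumes torch.distributed.rpc.init_rpc() has already been called in
    # this process, which Pipe requires
    import torch.nn as nn
    from torch.distributed.pipeline.sync import Pipe

    devices = [0, 1]  # this process's MP group, e.g. [g0, g1] or [g2, g3]
    stage0 = nn.Linear(16, 16).to(f"cuda:{devices[0]}")
    stage1 = nn.Linear(16, 16).to(f"cuda:{devices[1]}")
    block_pipe = Pipe(nn.Sequential(stage0, stage1), chunks=2)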

I pass mpu to deepspeed init as suggested in the docs:

    mpu.initialize_model_parallel(n_gpus_per_mp)
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=SimpleNamespace(**ds_args),  # expects an obj
        model=model,
        model_parameters=model_parameters,
        config_params=config,
        mpu=mpu,
    )

I init the mpu with 2, i.e. 2 gpus per MP group: mpu.initialize_model_parallel(2)

Now how do I launch deepspeed to use my PP and to achieve 2D Parallelism? Do I tell it about 4 gpus or just gpus 0 and 2?

  1. If I tell it about 4 gpus (well, letting it see all gpus) - it launches 4 processes, which I think is wrong, and it crashes with:
CUDA_VISIBLE_DEVICES=0,1,2,3,  deepspeed  program
[...]
  File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 256, in forward
    return self.weight * hidden_states
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:1!

gpus 2 and 1 shouldn't end up together in the pipeline - they are in the wrong groups.

  2. If I tell it only about gpus 0 and 2, it launches 2 processes, but then crashes with:
CUDA_VISIBLE_DEVICES=0,1,2,3,  deepspeed --include localhost:0,2 program

  0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "./finetune_trainer.py", line 373, in <module>
    main()
  File "./finetune_trainer.py", line 303, in main
    train_result = trainer.train(
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 919, in train
    tr_loss += self.training_step(model, inputs)
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1286, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1316, in compute_loss
    outputs = model(**inputs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 702, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/deepspeed/deepspeed/runtime/engine.py", line 824, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1833, in forward
    encoder_outputs = self.encoder(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1144, in forward
    layers[layer_id].to(device_id)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 673, in to
    return self._apply(convert)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
    module._apply(fn)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
    module._apply(fn)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 409, in _apply
    param_applied = fn(param)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 671, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: invalid device ordinal
Traceback (most recent call last):
  File "./finetune_trainer.py", line 373, in <module>
    main()
  File "./finetune_trainer.py", line 303, in main
    train_result = trainer.train(
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 919, in train
    tr_loss += self.training_step(model, inputs)
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1286, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1316, in compute_loss
    outputs = model(**inputs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 702, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/deepspeed/deepspeed/runtime/engine.py", line 824, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1833, in forward
    encoder_outputs = self.encoder(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1167, in forward
    outputs = block_pipe(inputs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipe.py", line 366, in forward
    return RRef(output)
RuntimeError: agent INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/distributed/rpc/rpc_agent.cpp":247, please report a bug to PyTorch. Current RPC agent is not set!
  3. Running direct tests on your MPU code, it has to see a world of size 4 for this to work, but I can only do that by exposing 4 gpus, and then I end up with 4 processes, which crash. The mpu functions return wrong groups if the world is of size 2.

Somehow I need to tell deepspeed that there are 4 gpus to work with but that I'm using a Pipeline.

So I'm stuck :(

Please help!

Thank you!

p.s. Do you by chance have some sort of Slack where quick questions of this kind could be asked? I'm sure the answer is very simple, but I couldn't find it anywhere in the docs. Some of us ask questions on the pytorch Slack, to which I believe you would be very gladly invited - I'm not an admin, but I would be happy to ask if you email me your email address and I will get you invited - I'm at stas@stason.org.

@ShadenSmith (Contributor) commented:

Hi @stas00, I haven't worked with torch's pipeline parallelism, so I'm not sure of the specific issues you are running into. You shouldn't need to restrict which devices are visible to DeepSpeed. Each GPU should have its own process, and those ranks are grouped into data and model parallel groups.

Here's a sneak preview of our Megatron/DeepSpeed 3D training, and how we set up a 3D topology.
You can also see our MPU modifications in that branch: https://github.com/jeffra/DSE/blob/megatron-deepspeed-pipeline/megatron/initialize.py#L159
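
Roughly speaking, the 3D arrangement there is described with DeepSpeed's topology helper. A minimal sketch (argument and method names are taken from the current pipeline code, so treat them as an approximation of what the branch does):

    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

    # 8 ranks arranged as 2 pipeline stages x 2 model-parallel x 2 data-parallel
    topo = PipeModelDataParallelTopology(num_pp=2, num_mp=2, num_dp=2)

    # every global rank gets a coordinate on each axis
    for rank in range(topo.world_size()):
        print(rank, topo.get_coord(rank))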

@stas00 (Collaborator, Author) commented Jan 29, 2021

Thank you, Shaden.

This clarification should help me to know how to debug this further.

Thank you for the sneak preview - I will try to derive from the code what's meant to happen.

@stas00 (Collaborator, Author) commented Jan 30, 2021

I tried to simplify the situation greatly: I still use the pipeline, but without spreading the layers out over multiple gpus, so the pipeline protocol gets launched but remains contained in just one gpu. DeepSpeed then comes in with 2 DP processes, and it still fails.

So now I have simplified the world to:

setup: 2 gpus

2 model parallel groups: [g0], [g1]
2 data parallel groups: [g0], [g1]

Now it crashes here:

 Traceback (most recent call last):
  File "./finetune_trainer.py", line 373, in <module>
    main()
  File "./finetune_trainer.py", line 303, in main
    train_result = trainer.train(
  File "/mnt/nvme1/code/huggingface/transformers-mp-pp/src/transformers/trainer.py", line 919, in train
    tr_loss += self.training_step(model, inputs)
  File "/mnt/nvme1/code/huggingface/transformers-mp-pp/src/transformers/trainer.py", line 1300, in training_step
    self.deepspeed.backward(loss)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/engine.py", line 903, in backward
    self.optimizer.backward(loss)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/zero/stage2.py", line 1609, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/tensor.py", line 225, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function CopyBackward returned an invalid gradient at index 5 - expected device cuda:0 but got cuda:1

In this situation there should be no interaction between gpu 0 and gpu 1 - they should be totally unrelated, as they are in 2 different DP groups. But you can see this:

expected device cuda:0 but got cuda:1

The error comes from DeepSpeed.

I wonder if there is some clash between the 2 RPC processes (since the pytorch pipeline requires the user to initialize its own RPC)?

If I disable RPC, I get:

  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipe.py", line 366, in forward
    return RRef(output)
RuntimeError: agent INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/distributed/rpc/rpc_agent.cpp":247, please report a bug to PyTorch. Current RPC agent is not set!
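
(For reference, the per-process RPC init that torch's Pipe expects looks roughly like this - a sketch, with RANK/WORLD_SIZE standing in for whatever the launcher exports:)

    import os
    import torch.distributed.rpc as rpc

    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # must be called once per process before constructing Pipe, even when
    # the pipeline stays on a single gpu
    rpc.init_rpc(name=f"worker{rank}", rank=rank, world_size=world_size)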

@ShadenSmith (Contributor) commented:

This error makes me think that the computation graph is crossing GPUs and autograd is complaining:

RuntimeError: Function CopyBackward returned an invalid gradient at index 5 - expected device cuda:0 but got cuda:1

DeepSpeed moves the model to device local_rank at initialization and then does not touch other GPUs except through NCCL communications (so no computation graph to worry about). Is Pipe moving any data besides that? DeepSpeed has only been tested with our own pipeline engine, and is ultimately designed for distributed computing and assumes that each process only accesses one GPU.

I think there may also be an issue with the model/data parallel groups. For model parallelism with two GPUs, you could have:

1 model parallel group: [g0, g1]
2 data parallel groups: [g0] [g1]

Or pure data parallelism would be:

2 model parallel groups: [g0] [g1]
1 data parallel group: [g0, g1]
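
As a concrete sketch of the first layout in plain torch.distributed calls (illustrative only; note that every rank must create every group and keep the one it belongs to):

    import torch.distributed as dist

    # 1 model parallel group: [g0, g1]
    mp_group = dist.new_group([0, 1])

    # 2 data parallel groups: [g0], [g1]
    for ranks in ([0], [1]):
        group = dist.new_group(ranks)
        if dist.get_rank() in ranks:
            dp_group = group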

@stas00 (Collaborator, Author) commented Jan 30, 2021

Thank you for this insight, @ShadenSmith - so let me validate that:

I modified your original mpu to add 2 additional functions that return these 2 groups - they mirror the functions DeepSpeed would call, but return the rank ids rather than the group objects - and the result looks correct:

# process 0
DP group [0, 1]
MP group [0]
pipeline partitioning: {0: [0, 1, 2, 3, 4, 5, 6]}

# process 1
DP group [0, 1]
MP group [1]
pipeline partitioning: {1: [0, 1, 2, 3, 4, 5, 6]}

i.e. my pipeline is locked to a single device per process - process 0 to gpu 0 and process 1 to gpu 1 - and so DP runs on 0+1. This is on purpose, in the hope of solving this simpler setup first.

I init the MPU with:

mpu.initialize_model_parallel(1)

I launch deepspeed with just:

deepspeed program

and this machine has only 2 gpus. It does the same if I pass --num_gpus=2

@stas00 (Collaborator, Author) commented Jan 30, 2021

I don't know if it helps, but I'm using the original MPU and my logging shows that the groups are correct. I just had to convert it into a class and add 2 debug functions:

get_data_parallel_group_device_ids()
get_model_parallel_group_device_ids()

The rest of the code is the same as https://github.com/microsoft/DeepSpeedExamples/blob/master/Megatron-LM/mpu/initialize.py

import torch
import torch.distributed

# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP_DEVICE_IDS = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_DEVICE_IDS = None

# adjusted from Megatron-LM/mpu/
class MPU:
    def initialize_model_parallel(self, model_parallel_size_):
        """
        Initialize model data parallel groups.

        Arguments:
            model_parallel_size: number of GPUs used to parallelize model.
                                **Important**: not the total number of gpus!

        Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
        use 2 GPUs to parallelize the model. The present function will
        create 4 model parallel groups and 2 data parallel groups as:
            4 model parallel groups:
                [g0, g1], [g2, g3], [g4, g5], [g6, g7]
            2 data parallel groups:
                [g0, g2, g4, g6], [g1, g3, g5, g7]

        Note that for efficiency, the caller should make sure adjacent ranks
        are on the same DGX box. For example if we are using 2 DGX-1 boxes
        with a total of 16 GPUs, rank 0 to 7 belong to the first box and
        ranks 8 to 15 belong to the second box.

        Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we
        use 2 GPUs to parallelize the model. The present function will
        create 2 model parallel groups and 2 data parallel groups as:
            2 model parallel groups:
                [g0, g1], [g2, g3]
            2 data parallel groups:
                [g0, g2], [g1, g3]

        """

        def ensure_divisibility(numerator, denominator):
            """Ensure that numerator is divisible by the denominator."""
            assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

        if torch.distributed.get_rank() == 0:
            print("> initializing model parallel with size {}".format(model_parallel_size_))
        # Get world size and rank. Ensure some consistencies.
        assert torch.distributed.is_initialized()
        world_size = torch.distributed.get_world_size()
        model_parallel_size = min(model_parallel_size_, world_size)
        ensure_divisibility(world_size, model_parallel_size)
        rank = torch.distributed.get_rank()

        print(f"MP size: {model_parallel_size}")
        print(f"world_size: {world_size}")
        print(f"rank: {rank}")

        # Build the data parallel groups.
        global _DATA_PARALLEL_GROUP
        global _DATA_PARALLEL_GROUP_DEVICE_IDS
        assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
        for i in range(model_parallel_size):
            ranks = range(i, world_size, model_parallel_size)
            group = torch.distributed.new_group(ranks)
            if i == (rank % model_parallel_size):
                #print(f"DP ranks: {list(ranks)}")
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

        # Build the model parallel groups.
        global _MODEL_PARALLEL_GROUP
        global _MODEL_PARALLEL_GROUP_DEVICE_IDS
        assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
        for i in range(world_size // model_parallel_size):
            ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
            group = torch.distributed.new_group(ranks)
            if i == (rank // model_parallel_size):
                #print(f"MP ranks: {list(ranks)}")
                _MODEL_PARALLEL_GROUP = group
                _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

    def model_parallel_is_initialized(self):
        """Check if model and data parallel groups are initialized."""
        if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
            return False
        return True

    def get_model_parallel_group_device_ids(self):
        """Get the model parallel device ids of the group the caller rank belongs to."""
        assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
        return _MODEL_PARALLEL_GROUP_DEVICE_IDS

    def get_model_parallel_group(self):
        """Get the model parallel group the caller rank belongs to."""
        assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
        return _MODEL_PARALLEL_GROUP

    def get_data_parallel_group_device_ids(self):
        """Get the data parallel device ids of the group the caller rank belongs to."""
        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
        return _DATA_PARALLEL_GROUP_DEVICE_IDS

    def get_data_parallel_group(self):
        """Get the data parallel group the caller rank belongs to."""
        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
        return _DATA_PARALLEL_GROUP

    def get_model_parallel_world_size(self):
        """Return world size for the model parallel group."""
        return torch.distributed.get_world_size(group=self.get_model_parallel_group())

    def get_model_parallel_rank(self):
        """Return my rank for the model parallel group."""
        return torch.distributed.get_rank(group=self.get_model_parallel_group())

    def get_model_parallel_src_rank(self):
        """Calculate the global rank corresponding to a local rank zero
        in the model parallel group."""
        global_rank = torch.distributed.get_rank()
        local_world_size = self.get_model_parallel_world_size()
        return (global_rank // local_world_size) * local_world_size

    def get_data_parallel_world_size(self):
        """Return world size for the data parallel group."""
        return torch.distributed.get_world_size(group=self.get_data_parallel_group())

    def get_data_parallel_rank(self):
        """Return my rank for the data parallel group."""
        return torch.distributed.get_rank(group=self.get_data_parallel_group())

    def destroy_model_parallel(self):
        """Set the groups to none."""
        global _MODEL_PARALLEL_GROUP
        global _MODEL_PARALLEL_GROUP_DEVICE_IDS
        global _DATA_PARALLEL_GROUP
        global _DATA_PARALLEL_GROUP_DEVICE_IDS
        _MODEL_PARALLEL_GROUP = None
        _MODEL_PARALLEL_GROUP_DEVICE_IDS = None
        _DATA_PARALLEL_GROUP = None
        _DATA_PARALLEL_GROUP_DEVICE_IDS = None
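
Usage sketch (assuming torch.distributed has already been initialized by the deepspeed launcher), matching the printouts above:

    mpu = MPU()
    mpu.initialize_model_parallel(1)  # 1 gpu per MP group in the simplified 2-gpu setup
    print(mpu.get_data_parallel_group_device_ids())   # -> [0, 1]
    print(mpu.get_model_parallel_group_device_ids())  # -> [0] on rank 0, [1] on rank 1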

@stas00 (Collaborator, Author) commented Jan 30, 2021

I made partial progress. When I was fighting to get over the limitations of the pipeline API I left a few hardcoded .to(0) calls, and those were the cause of this last failure with the 2x1 groups! This is great! Next I will do 4x1 and then 2x2.

I have to start thinking in 2D; clearly I wrote the pipeline with 1D thinking.

@stas00 (Collaborator, Author) commented Jan 31, 2021

OK, I think I can see the problem now.

So we have 4 gpus and we want to build a 2D Pipeline with:

      pp
dp0 [0, 1]
dp1 [2, 3] 

Now I start deepspeed with 4 gpus, and it starts 4 processes instead of 2!

Obviously my code tries to fire off 4 pipelines instead of 2 and everything breaks.

Things work just fine in the 2-gpu world with a pipeline of depth 1 (no pipeline per se):

    pp
dp0 [0]
dp1 [1]

Since we have just 2 processes, all works out correctly.

So, what's the proper way to launch a 4-gpu setup on deepspeed, so that the Pipeline gets hit with only 2 processes and not 4? Do I just exit(0) from the processes that shouldn't be running? No, that doesn't work, since the dist world is of size 4 and it gets stuck waiting for those 2 processes.

I tried deepspeed --include localhost:0,2 but it doesn't work, as it shrinks the world to 2 gpus.

I hope my explanation is understandable.

Thank you.

@stas00 (Collaborator, Author) commented Jan 31, 2021

I think I now understand that this of course can't work, since I was following
https://www.deepspeed.ai/features/#support-for-custom-model-parallelism
and that is not about PP, but about the super-confusing abbreviation MP, which means different things in different contexts - here it means horizontal MP and not vertical MP/PP. And there are no instructions on how to integrate PP. So I have been trying to fix the wrong thing.

OK, so let's start from scratch.

I have a pipeline that is not the DeepSpeed pipeline. How do I integrate it with DeepSpeed? Thank you.

  • My guess is that MPU is still the answer - which functions should it have? Will DeepSpeed automatically detect and call them?
  • Specifically regarding the MPU, how should my MPU tell DeepSpeed that it's pipeline mode that I want and not MP? That is, I want 2D and not 3D, and I want PP+DP and not MP+DP.
  • I suppose that DeepSpeed, despite a world of 4 gpus, will auto-magically make sure that only 2 processes are running for the 2-dp/2-pp setup and not 4 - since this is where all the problems so far have come from.

I started looking at your source code and I see that an outside PP isn't supported, and that by MP you mean horizontal MP:

        if self.mpu is None:
            self.data_parallel_group = _initialize_parameter_parallel_groups()
            self.dp_world_size = dist.get_world_size()
            self.mp_world_size = 1
            self.broadcast_src_rank = 0
        else:
            self.data_parallel_group = self.mpu.get_data_parallel_group()
            self.dp_world_size = self.mpu.get_data_parallel_world_size()
            self.mp_world_size = self.mpu.get_model_parallel_world_size()
            self.broadcast_src_rank = _get_global_rank(
                self.mpu.get_data_parallel_group(),
                0)

So basically I misinterpreted the features.md doc as saying that I could use a non-DeepSpeed PP with DeepSpeed, but it doesn't look like that's possible at the moment.
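
(For contrast, a minimal sketch of DeepSpeed's own pipeline path - what Shaden called "our own pipeline engine". The layer contents and the reuse of ds_args/config from my first snippet are placeholders; the class names come from deepspeed.pipe:)

    from types import SimpleNamespace
    import torch.nn as nn
    import deepspeed
    from deepspeed.pipe import PipelineModule

    deepspeed.init_distributed()  # PipelineModule needs torch.distributed to be up

    # placeholder stages; a real model would be sliced into its blocks
    layers = [nn.Linear(16, 16) for _ in range(4)]
    net = PipelineModule(layers=layers, num_stages=2, loss_fn=nn.MSELoss())

    engine, _, _, _ = deepspeed.initialize(
        args=SimpleNamespace(**ds_args),  # same args object as in my first snippet
        model=net,
        model_parameters=[p for p in net.parameters() if p.requires_grad],
        config_params=config,
    )
    # training is then driven with engine.train_batch() rather than a manual
    # forward/backward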
