
global process count incorrect with elastic, fault tolerant training #6853

Closed
srib opened this issue Apr 6, 2021 · 23 comments · Fixed by #6941

Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task), waiting on author (Waiting on user action, correction, or update)
Milestone: 1.2.x

srib commented Apr 6, 2021

🐛 Bug

Problem

The total number of processes (world size) is set incorrectly.

Context

I am trying to run elastic training with torchelastic. I have tried with both gloo and nccl backends.

Error message

Error coming from gloo backend:

Traceback (most recent call last):
  File "train_hydra.py", line 20, in hydra_main
    train(cfg)
  File "/bdata/bdata1/sribkain/learnseis/learnseis/training.py", line 39, in train
    t.fit(module, data_module)
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 496, in fit
    self.pre_dispatch()
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 525, in pre_dispatch
    self.accelerator.pre_dispatch()
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 83, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 243, in pre_dispatch
    self.init_ddp_connection(self.global_rank, self.world_size)
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 226, in init_ddp_connection
    torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 432, in init_process_group
    timeout=timeout)
  File "/ldata/Code/salt-identification/SRIBKAIN_ENVS/pl_env/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 503, in _new_process_group_helper
    timeout=timeout)
RuntimeError: [enforce fail at /pytorch/third_party/gloo/gloo/context.cc:27] rank < size. 13 vs 8

NCCL backend gives this error: pytorch/pytorch#20313

Please reproduce using the BoringModel

I am running the ImageNet example from PL using torchvision.models.resnet34. Happy to reproduce with the BoringModel if needed.

Before launching, I exported the GLOO_SOCKET_IFNAME variable and set it to the appropriate interface name.

On node 0:

PL_TORCH_DISTRIBUTED_BACKEND=gloo python -m torchelastic.distributed.launch --nnodes=1:5 --rdzv_id='nodockertestelasticlaunch7' --rdzv_backend=etcd --rdzv_endpoint=10.18.0.15:2379 train_hydra.py +experiment=elastic_config.yaml

On node 1:

PL_TORCH_DISTRIBUTED_BACKEND=gloo python -m torchelastic.distributed.launch --nnodes=1:5 --rdzv_id='nodockertestelasticlaunch7' --rdzv_backend=etcd --rdzv_endpoint=10.18.0.15:2379 train_hydra.py +experiment=elastic_config.yaml

To Reproduce

Run the launch commands above, then interrupt one node or add a new one while training is running.

Expected behavior

To be able to run distributed fault tolerant training :)

Environment

Output of collect_env_details.py:

* CUDA:
        - GPU:
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
        - available:         True
        - version:           10.2
* Packages:
        - numpy:             1.19.2
        - pyTorch_debug:     False
        - pyTorch_version:   1.6.0
        - pytorch-lightning: 1.2.6
        - tqdm:              4.48.2
        - torchelastic:      0.2.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                -
        - processor:         x86_64
        - python:            3.7.7
        - version:           #88-Ubuntu SMP Tue Feb 11 20:11:34 UTC 2020

@srib added the "bug" and "help wanted" labels on Apr 6, 2021

srib commented Apr 6, 2021

Related issues: #6797, #6527

srib commented Apr 6, 2021

Shouldn't torch_distrib.init_process_group use the world size from os.environ['WORLD_SIZE'] instead of the world_size argument?

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

From: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/ddp.py#L234

@tchaton added the "priority: 0" label on Apr 7, 2021

ananthsub commented Apr 7, 2021

@srib what does your os.environ look like before init_ddp_connection is called? could you print this out? Lightning currently doesn't handle the case where min_nodes != max_nodes in torchelastic. The trainer needs to listen for when the world size changes and rebalance accordingly, which currently isn't implemented.
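
For reference, a minimal way to dump the relevant variables right before init_ddp_connection is called (which variables matter here is an assumption based on what torchelastic and Lightning typically set; anything not exported will show as unset):

    import os

    # Print the distributed-training environment variables so the trainer's
    # cached world_size can be compared against what the launcher actually set.
    for key in ("RANK", "LOCAL_RANK", "GROUP_RANK", "NODE_RANK",
                "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
        print(f"{key}={os.environ.get(key, '<unset>')}")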

srib commented Apr 7, 2021

@ananthsub Thank you for the reply.

All the following cases were run with the following command:

PL_TORCH_DISTRIBUTED_BACKEND=gloo python -m torchelastic.distributed.launch --nnodes=1:5 --rdzv_id='nodockertestelasticlaunch7' --rdzv_backend=etcd --rdzv_endpoint=10.18.0.15:2379 train_hydra.py +experiment=elastic_config.yaml

It seems to work fine even when min_nodes != max_nodes. I ran the following tests; here is what I have to report.

  1. On a single node with 8 GPUs, here is what my os.environ looks like before init_ddp_connection is called:
     world_size: 8
     os.environ["WORLD_SIZE"]: 8
  2. On two nodes with 8 GPUs each:
     world_size: 8
     os.environ["WORLD_SIZE"]: 16
  3. Interrupt one node:
     world_size: 8
     os.environ["WORLD_SIZE"]: 8
  4. Add a new node again:
     world_size: 8
     os.environ["WORLD_SIZE"]: 16

I even got training to run with the following change:

Before:

torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

After:

torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=int(os.environ["WORLD_SIZE"]))
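
In context, the whole method with that one-line workaround would look roughly like this (a sketch based on the snippet quoted above; log, self.cluster_environment and self.torch_distributed_backend come from the existing plugin):

    import os
    import torch
    import torch.distributed as torch_distrib

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

        if not torch.distributed.is_initialized():
            # Workaround: trust the environment, which torchelastic rewrites on
            # every re-rendezvous, instead of the world_size cached by the trainer.
            world_size = int(os.environ["WORLD_SIZE"])
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)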

Am I missing something here?

srib commented Apr 8, 2021

@ananthsub Can I submit a PR for this or are there any other implications that I am missing here?

ananthsub commented:

@awaelchli what do you think of making global_rank a method on the ClusterEnvironment? world_size is already there, and we know whether the cluster environment spawns processes based on creates_children.

Given this state is already available from the env variables, it could be used to set rank_zero_only.rank on trainer init as well, and we can reset it in the DDP/parallel plugin whenever we spawn. We can add a setter for world size and global rank as well. For Slurm and torchelastic, the setter should be a no-op: we can always set the fields to whatever the environment says and ignore the passed-in argument.

Then the cluster environment becomes the source of truth for these fields, which can be used in the training type plugin, which is used in the accelerator, which is used in the trainer.
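
A rough sketch of that idea (illustrative names; not necessarily the API that was eventually merged):

    import os
    from abc import ABC, abstractmethod

    class ClusterEnvironment(ABC):
        # The environment, not the trainer, owns world size and global rank.

        @abstractmethod
        def creates_children(self) -> bool:
            ...

        @abstractmethod
        def world_size(self) -> int:
            ...

        @abstractmethod
        def global_rank(self) -> int:
            ...

        def set_world_size(self, size: int) -> None:
            ...

        def set_global_rank(self, rank: int) -> None:
            ...

    class TorchElasticEnvironment(ClusterEnvironment):
        def creates_children(self) -> bool:
            return True  # torchelastic launches the worker processes itself

        def world_size(self) -> int:
            return int(os.environ["WORLD_SIZE"])

        def global_rank(self) -> int:
            return int(os.environ["RANK"])

        def set_world_size(self, size: int) -> None:
            pass  # no-op: the launcher's environment is the source of truth

        def set_global_rank(self, rank: int) -> None:
            pass  # no-op: the launcher's environment is the source of truth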

awaelchli commented:

Sounds good yes. Let's try it. This will affect all major plugins, so we need to be careful :)

awaelchli commented:

@srib ok I think I understand now the bug you are experiencing. The world size is changing in the middle of training, and we are keeping the old value saved. And so we should read the world size always from the cluster environment (which will read it from the environment)
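
In plugin terms that would look roughly like this (illustrative property names, not the actual implementation):

    class DDPPlugin:
        def __init__(self, cluster_environment):
            self.cluster_environment = cluster_environment

        @property
        def world_size(self) -> int:
            # Re-read on every access instead of caching the value computed when
            # the trainer was constructed, so a re-rendezvous is picked up.
            return self.cluster_environment.world_size()

        @property
        def global_rank(self) -> int:
            return self.cluster_environment.global_rank()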

awaelchli commented:

Found a related bug in master:

Running
python pl_examples/basic_examples/simple_image_classifier.py --accelerator ddp --gpus 2
then on rank 0 we have LightningEnvironment but on rank 1 it is TorchElasticEnvironment
arghh

ananthsub commented:

@awaelchli is it because the lightning environment uses the same environment variables as torchelastic for local rank?
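
For illustration only (not Lightning's actual detection code), the ambiguity is that both launchers export LOCAL_RANK, while GROUP_RANK is only set by torchelastic:

    import os

    def looks_like_torchelastic() -> bool:
        # LOCAL_RANK alone is ambiguous: Lightning's own DDP script launcher and
        # torchelastic both export it. GROUP_RANK (the node's rank in the
        # rendezvous) is only set by torchelastic.
        return "LOCAL_RANK" in os.environ and "GROUP_RANK" in os.environ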

srib commented Apr 9, 2021

@srib ok I think I understand now the bug you are experiencing. The world size is changing in the middle of training, and we are keeping the old value saved. And so we should read the world size always from the cluster environment (which will read it from the environment)

@awaelchli Thanks. If I understand torchelastic's behavior correctly, it updates the os.environ dictionary with all the relevant information about the cluster. So, we should read the cluster environment as you pointed out.
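
Concretely, everything needed for init can be re-read from os.environ at the moment the process group is (re)initialized, e.g. (a sketch; the exact variable set is an assumption based on torchelastic 0.2.0):

    import os

    def cluster_state_from_env() -> dict:
        # torchelastic rewrites these for every worker on each (re-)rendezvous, so
        # they must be read at init time rather than cached at trainer construction.
        return {
            "global_rank": int(os.environ["RANK"]),
            "local_rank": int(os.environ["LOCAL_RANK"]),
            "node_rank": int(os.environ["GROUP_RANK"]),
            "world_size": int(os.environ["WORLD_SIZE"]),
            "master_addr": os.environ["MASTER_ADDR"],
            "master_port": os.environ["MASTER_PORT"],
        }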

awaelchli commented:

@ananthsub yes, I'll try to send a PR today. I discovered that while trying your idea with global rank in the cluster environment.
@srib to enable this we need to change the rank access in a few places in our plugins. I couldn't get a quick fix yet without other things breaking; we need to do it properly like ananth said, using the cluster env in all places.

srib commented Apr 9, 2021

@awaelchli Sounds good. I have a short term workaround for now. I will wait till the fix makes it to the master. If you need any help with testing, please let me know. Appreciate all the effort and help.

ananthsub commented Apr 9, 2021

@awaelchli I think this is another issue:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b85cfbe8f350a89c93ff967fa416859be7ebb4f3/pytorch_lightning/plugins/training_type/ddp.py#L283-L285

We shouldn't be messing with these environment variables, because we're not distinguishing whether they were set by torchelastic or by the Lightning spawn environment. Users running with torchelastic would mysteriously see the world size disappear.
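
Paraphrasing the linked lines based on this discussion (not a verbatim quote of the file):

    import os

    # Roughly the shape of the linked post_dispatch cleanup (paraphrased):
    def post_dispatch(self) -> None:
        if "WORLD_SIZE" in os.environ:
            # Problematic under torchelastic: WORLD_SIZE was set by the launcher,
            # not by Lightning, yet it gets deleted at the end of fit().
            del os.environ["WORLD_SIZE"]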

awaelchli commented:

true. would be good to know why this is there in the first place ... :(
maybe the least serious bug here, because it happens in post_dispatch so basically at the end of training.
good observation

ananthsub commented:

true. would be good to know why this is there in the first place ... :(
maybe the least serious bug here, because it happens in post_dispatch so basically at the end of training.
good observation

The issue is if someone runs trainer.fit() and then trainer.test() right afterward.
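
For example, a sequence like this (model and datamodule are placeholders for any LightningModule / LightningDataModule) would hit the missing variable when test() sets up DDP again:

    from pytorch_lightning import Trainer

    trainer = Trainer(accelerator="ddp", gpus=8)
    trainer.fit(model, datamodule=datamodule)
    # If fit()'s cleanup deleted WORLD_SIZE, the environment prepared by the
    # torchelastic launcher is gone by the time test() initializes DDP again.
    trainer.test(model, datamodule=datamodule)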

awaelchli commented:

@srib Would you like to try out my branch #6941? I believe I have the major stuff sorted out, just not 100% sure about all the multi-node stuff.

@awaelchli added this to the 1.2.x milestone on Apr 12, 2021
@edenlightning added the "waiting on author" label on Apr 12, 2021

srib commented Apr 12, 2021

@awaelchli Sorry for the delay in responding. I had to make a few tweaks to my code as it was using the hydra configs from the latest master.

I can confirm that it works fine. It does assign the global ranks correctly now. Please note that I only tested it with the gloo backend. Let me also run a test with the nccl backend and I will confirm that it fixes my problem.

awaelchli commented:

@srib that's amazing. thanks for confirming, that's very good to know

srib commented Apr 13, 2021

Thanks @awaelchli and @ananthsub for this PR. When is the next release so that I can be lazy and rely on the wheel? 😄

awaelchli commented:

The next release is planned for tomorrow. This is the branch that we will release as 1.2.8: #6983

srib commented Apr 14, 2021

Thanks @awaelchli!

awaelchli commented:

1.2.8 is out now. Hope this works out well for you on the nccl backend; otherwise let me know.
Cheers!
