Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Better support for rank_zero_only setting for SLURM and torchelastic #6802

Merged
merged 22 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Set better defaults for `rank_zero_only.rank` when training is launched with SLURM and torchelastic ([#6802](https://github.com/PyTorchLightning/pytorch-lightning/pull/6802/))


- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))


Expand All @@ -196,7 +199,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


## [1.2.6] - 2021-03-30

### Changed
Expand Down
12 changes: 11 additions & 1 deletion pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,18 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn


# TODO: this should be part of the cluster environment
def _get_rank() -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this directly to SLURM cluster environment ?

What if RANK, SLURM_PROCID or LOCAL_RANK are different ? Should we take the latest or did you order rank_keys based on priority ?

Best,
T.C

Copy link
Contributor Author

@ananthsub ananthsub Apr 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RANK = torchelastic
SLURM_PROCID = slurm
LOCAL_RANK = parity with existing setup though I think it's not right

I set local rank last as RANK and SLURM_PROCID correspond to global rank already. The linked issue has more discussion, but I think we should make global rank and world size properties of the cluster environment. So the cluster environment becomes the source of truth propagating from Cluster environment => training type plugin => accelerator => trainer.

The main issue now is the global rank isn't set on trainer initialization. If the cluster environment is marked as creating children, then we can leave the initialization of these fields for later, but both torchelastic and slurm have this data already available in the environment variables, and we should expose that as soon as possible (on Trainer init) for users to read this state.

Currently, this waits for trainer.fit() to be called, going through the accelerator setup flow for these properties to be initialized on the training type plugin.

rank_keys = ('RANK', 'SLURM_PROCID', 'LOCAL_RANK')
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
return 0


# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, 'rank', int(os.environ.get('LOCAL_RANK', 0)))
rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank())


def _warn(*args, **kwargs):
Expand Down
67 changes: 67 additions & 0 deletions tests/utilities/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import pytest


@mock.patch.dict(os.environ, {"SLURM_PROCID": "0"})
def test_rank_zero_slurm():
""" Test that SLURM environment variables are properly checked for rank_zero_only. """
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()

@rank_zero_only
def foo():
# The return type is optional because on non-zero ranks it will not be called
return 1

x = foo()
assert x == 1


@mock.patch.dict(os.environ, {"RANK": "0"})
def test_rank_zero_torchelastic():
""" Test that torchelastic environment variables are properly checked for rank_zero_only. """
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()

@rank_zero_only
def foo():
# The return type is optional because on non-zero ranks it will not be called
return 1

x = foo()
assert x == 1
ananthsub marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("rank_key,rank", [
("RANK", "1"),
("SLURM_PROCID", "2"),
("LOCAL_RANK", "3"),
])
def test_rank_zero_none_set(rank_key, rank):
""" Test that function is not called when rank environment variables are not global zero. """

with mock.patch.dict(os.environ, {rank_key: rank}):
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

@rank_zero_only
def foo():
return 1

x = foo()
assert x is None