
Commit

Fix missing deepspeed distributed call (#9540)
Sean Naren committed Sep 22, 2021
1 parent 562e18f commit 7f2b9fc
Showing 3 changed files with 28 additions and 15 deletions.
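
For context: the code path touched by this commit runs whenever DeepSpeed is enabled through the Trainer, as exercised by the updated test further down. Below is a minimal, hypothetical usage sketch; it assumes two GPUs, an installed `deepspeed` package, and uses the repository's `BoringModel` as a stand-in for a real LightningModule (the import path is taken from the test suite and may differ outside this repo).

```python
# Hypothetical usage sketch: enabling the DeepSpeed plugin so that
# setup_distributed() -> _init_deepspeed_distributed() -> deepspeed.init_distributed()
# runs during trainer.fit().
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from tests.helpers.boring_model import BoringModel  # placeholder model used by the test suite

model = BoringModel()
trainer = Trainer(plugins=[DeepSpeedPlugin(stage=3)], gpus=2, precision=16, fast_dev_run=True)
trainer.fit(model)  # with this fix, deepspeed.init_distributed() is called during setup
```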
CHANGELOG.md: 3 additions, 0 deletions
@@ -381,6 +381,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
 
 
+- Fixed missing deepspeed distributed call ([#9540](https://github.com/PyTorchLightning/pytorch-lightning/pull/9540))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
pytorch_lightning/plugins/training_type/deepspeed.py: 18 additions, 12 deletions
@@ -35,6 +35,7 @@
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

@@ -334,33 +335,38 @@ def _load_config(self, config):
         return config
 
     def setup_distributed(self):
-        super().setup_distributed()
+        reset_seed()
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        self._init_deepspeed_distributed()
+
+        # set the ranks and devices
+        self.dist.rank = self.global_rank
+        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
 
-    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
+    def _init_deepspeed_distributed(self) -> None:
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup
-            global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
-            world_size = world_size if world_size is not None else self.cluster_environment.world_size()
-            self._set_node_environment_variables(global_rank, world_size)
+            self._set_node_environment_variables()
             log.info(
                 "initializing deepspeed distributed: "
-                f"GLOBAL_RANK: {global_rank}, "
-                f"MEMBER: {global_rank + 1}/{world_size}"
+                f"GLOBAL_RANK: {self.global_rank}, "
+                f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
             deepspeed.init_distributed(
                 self.torch_distributed_backend, distributed_port=self.cluster_environment.master_port()
             )
 
-    def _set_node_environment_variables(
-        self, global_rank: Optional[int] = None, world_size: Optional[int] = None
-    ) -> None:
+    def _set_node_environment_variables(self) -> None:
         os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        os.environ["RANK"] = str(global_rank)
-        os.environ["WORLD_SIZE"] = str(world_size)
+        os.environ["RANK"] = str(self.global_rank)
+        os.environ["WORLD_SIZE"] = str(self.world_size)
         os.environ["LOCAL_RANK"] = str(self.local_rank)
 
     @property
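
For orientation, the new `setup_distributed` path on a non-Windows node boils down to exporting the rendezvous variables and handing off to `deepspeed.init_distributed`. The sketch below is a hypothetical, standalone restatement of that flow; the address, port, backend, and rank values are illustrative placeholders for what the plugin actually reads from its cluster environment.

```python
# Hypothetical sketch (not part of the commit) of what
# _set_node_environment_variables() + _init_deepspeed_distributed() amount to.
import os

import deepspeed


def init_deepspeed_for_rank(global_rank: int, local_rank: int, world_size: int) -> None:
    # Rendezvous variables normally taken from the cluster environment.
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # placeholder master address
    os.environ["MASTER_PORT"] = "29500"      # placeholder master port
    os.environ["RANK"] = str(global_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_RANK"] = str(local_rank)

    # The distributed initialization that was previously being skipped for DeepSpeed runs.
    deepspeed.init_distributed("nccl", distributed_port=29500)
```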
tests/plugins/test_deepspeed_plugin.py: 7 additions, 3 deletions
@@ -23,6 +23,7 @@
 from tests.helpers.runif import RunIf
 
 if _DEEPSPEED_AVAILABLE:
+    import deepspeed
     from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 
 
@@ -383,12 +384,15 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None:
 
 @RunIf(min_gpus=2, deepspeed=True, special=True)
 def test_deepspeed_multigpu(tmpdir):
-    """Test to ensure that DeepSpeed with multiple GPUs works."""
+    """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized
+    correctly."""
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
     )
-    trainer.fit(model)
+    with mock.patch("deepspeed.init_distributed", wraps=deepspeed.init_distributed) as mock_deepspeed_distributed:
+        trainer.fit(model)
+    mock_deepspeed_distributed.assert_called_once()
     trainer.test(model)
 
     _assert_save_model_is_equal(model, tmpdir, trainer)
@@ -810,7 +814,7 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat
     plugin = trainer.training_type_plugin
     assert isinstance(plugin, DeepSpeedPlugin)
     with mock.patch("platform.system", return_value=platform) as mock_platform:
-        plugin.init_ddp_connection()
+        plugin._init_deepspeed_distributed()
     mock_deepspeed_distributed.assert_called()
     mock_platform.assert_called()
     if platform == "Windows":
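
A side note on the testing technique used in `test_deepspeed_multigpu`: `mock.patch(..., wraps=...)` replaces the target with a mock that still delegates to the real function, so the test can assert the call happened without changing behaviour. A small self-contained illustration of the pattern (the `greet` function is invented for the example):

```python
# Minimal illustration of mock.patch(..., wraps=...): the real function runs,
# and the mock records the call for later assertions.
from unittest import mock


def greet(name: str) -> str:
    return f"hello {name}"


with mock.patch(f"{__name__}.greet", wraps=greet) as mock_greet:
    assert greet("deepspeed") == "hello deepspeed"  # real implementation still executed
mock_greet.assert_called_once_with("deepspeed")
```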

0 comments on commit 7f2b9fc
