Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: move specific accelerator code x/n #3457

Merged
merged 4 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ def _clip_gradients(self, optimizer):

def on_train_epoch_end(self):
pass

def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand Down Expand Up @@ -199,3 +200,10 @@ def test_step_end(self, output):

def barrier(self, name: str = None):
torch_distrib.barrier()

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand Down Expand Up @@ -278,3 +279,10 @@ def _check_can_spawn_children(self):

def barrier(self, name: str = None):
torch_distrib.barrier()

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop
8 changes: 8 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist

try:
from apex import amp
Expand Down Expand Up @@ -195,3 +196,10 @@ def test_step(self, args):

def barrier(self, name: str = None):
torch_distrib.barrier()

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop
7 changes: 7 additions & 0 deletions pytorch_lightning/accelerators/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,10 @@ def clip_gradients(self, optimizer):

def barrier(self, name: str = None):
torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32)
stop = xm.mesh_reduce("stop_signal", stop, sum)
torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
should_stop = int(stop.item()) == self.trainer.world_size
return should_stop
19 changes: 2 additions & 17 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"""
import numpy as np
import torch
import torch.distributed as dist

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
Expand Down Expand Up @@ -216,22 +215,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
self._stop_distributed_training(trainer, pl_module)

def _stop_distributed_training(self, trainer, pl_module):

# in ddp make sure all processes stop when one is flagged
if trainer.use_ddp or trainer.use_ddp2:
stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
trainer.should_stop = stop == trainer.world_size

if trainer.use_tpu:
stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
stop = xm.mesh_reduce("stop_signal", stop, sum)
torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
trainer.should_stop = int(stop.item()) == trainer.world_size
should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module)
trainer.should_stop = should_stop

def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
Expand Down
32 changes: 4 additions & 28 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,14 @@
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
from apex import amp
except ImportError:
amp = None

try:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True

try:
from omegaconf import Container
except ImportError:
Expand Down Expand Up @@ -171,6 +157,7 @@ class TrainerIOMixin(ABC):
scaler: ...
use_tpu: bool
amp_backend: AMPType
accelerator_backend: Accelerator

def get_model(self):
is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))
Expand Down Expand Up @@ -202,19 +189,8 @@ def restore_weights(self, model: LightningModule):
if self.resume_from_checkpoint is not None:
self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)

# wait for all models to restore weights
if self.use_ddp or self.use_ddp2:
# wait for all processes to catch up
torch_distrib.barrier()

# wait for all models to restore weights
if self.on_tpu and XLA_AVAILABLE:
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights")

elif self.use_horovod:
# wait for all processes to catch up
hvd.join()
# wait for all to catch up
self.accelerator_backend.barrier('TrainerIOMixin.restore_weights')

# clear cache after restore
if self.on_gpu:
Expand Down
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,6 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary

try:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
HOROVOD_AVAILABLE = False
else:
HOROVOD_AVAILABLE = True


class TrainLoop:

Expand Down