test selecting the correct backend. temp backends while slurm and TE are decoupled #3848

Merged · 2 commits · Oct 4, 2020
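
The backend-selection tests themselves are not visible in this diff view. Below is a minimal sketch of what such a test can look like, faking a SLURM-like environment through environment variables and checking the backend type from an on_fit_start callback. The env-var values, the TinyModel, and the hook signature are assumptions for illustration, not taken from this PR.

import os
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning import accelerators


class TinyModel(LightningModule):
    # minimal placeholder module; it only has to survive until on_fit_start fires
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(8, 32), torch.zeros(8, dtype=torch.long))
        return DataLoader(data, batch_size=4)


@mock.patch.dict(os.environ, {
    "SLURM_NTASKS": "2",           # assumed: matches num_processes so SLURM "manages" the tasks
    "SLURM_JOB_NAME": "SOME_NAME",
    "SLURM_NODEID": "0",
    "SLURM_LOCALID": "0",
})
def test_accelerator_choice_ddp_cpu_slurm(tmpdir):
    class CheckBackend(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMBackend)
            raise SystemExit()  # stop before any real DDP work happens

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend="ddp_cpu",
        num_processes=2,
        callbacks=[CheckBackend()],
    )
    with pytest.raises(SystemExit):
        trainer.fit(TinyModel())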
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -9,3 +9,5 @@
from pytorch_lightning.accelerators.horovod_backend import HorovodBackend
from pytorch_lightning.accelerators.ddp_slurm_backend import DDPSLURMBackend
from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_torchelastic_backend import DDPCPUTorchElasticBackend
from pytorch_lightning.accelerators.ddp_cpu_slurm_backend import DDPCPUSLURMBackend
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -158,6 +158,9 @@ def select_accelerator(self):
        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
        use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"

        use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic()
        use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks

        # ddp script mode uses the same flags as TE
        # TODO: decouple from TE
        if os.environ.get('PL_DDP_PID', False):
@@ -167,9 +170,15 @@ def select_accelerator(self):
        if self.trainer.use_ddp2:
            accelerator_backend = accelerators.DDP2Backend(self.trainer)

        elif use_ddp_cpu_slurm:
            accelerator_backend = accelerators.DDPCPUSLURMBackend(self.trainer)

        elif use_slurm_ddp:
            accelerator_backend = accelerators.DDPSLURMBackend(self.trainer)

        elif use_ddp_cpu_torch_elastic:
            accelerator_backend = accelerators.DDPCPUTorchElasticBackend(self.trainer)

        elif use_torchelastic_ddp:
            accelerator_backend = accelerators.DDPTorchElasticBackend(self.trainer)

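The two new flags combine the existing ddp_cpu path with cluster detection: is_slurm_managing_tasks is set elsewhere in the connector from the SLURM environment, and _is_using_torchelastic() looks for variables exported by the elastic launcher. Roughly, as a sketch (the real helper lives further down in accelerator_connector.py and is not part of this hunk, so details may differ):

import os

def _is_using_torchelastic_sketch():
    # approximation: the TorchElastic launcher exports WORLD_SIZE plus a
    # GROUP_RANK/NODE_RANK for each worker process
    return 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)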
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -29,6 +29,7 @@
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities.exceptions import MisconfigurationException


try:
@@ -93,6 +94,9 @@ def _call_children_scripts(self):
        # when the trainer script was called the device has already been scoped by the time
        # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
        # but forward the GPUs selected via environment variables
        if self.trainer.data_parallel_device_ids is None:
            raise MisconfigurationException('you selected distributed_backend="ddp" but did not set Trainer(gpus=...)')

        os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids])
        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
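The new guard above replaces a bare TypeError (iterating over a None device list) with an explicit error when script-mode ddp is requested without GPUs. A hypothetical trigger:

from pytorch_lightning import Trainer

trainer = Trainer(distributed_backend='ddp')  # note: no gpus=... given
# trainer.fit(model) would now raise MisconfigurationException from _call_children_scripts()
# instead of failing on ','.join(...) over a None data_parallel_device_ids.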
173 changes: 173 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py
@@ -0,0 +1,173 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist

from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed


try:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig
except ImportError:
    HYDRA_AVAILABLE = False
else:
    HYDRA_AVAILABLE = True


# -------------------------------------------
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# TEMP CLASS WHILE WE DECOUPLE TE FROM DDP
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# -------------------------------------------
class DDPCPUSLURMBackend(Accelerator):

    def __init__(self, trainer, cluster_environment=None):
        super().__init__(trainer, cluster_environment)
        self.task_idx = None
        self._has_spawned_children = False
        self.dist = LightningDistributed()

    def setup(self, model):
        self.trainer.model = model
        self.task_idx = int(os.environ['SLURM_LOCALID'])

    def train(self):
        model = self.trainer.model
        self.ddp_train(process_idx=self.task_idx, model=model)

    def set_world_ranks(self, process_idx):
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def model_to_device(self, model, process_idx):
        model.cpu()

    def get_device_ids(self):
        device_ids = None
        return device_ids

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.reduce_op.SUM)
        dist.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def broadcast(self, obj, src=0):
        return self.dist.broadcast(obj)

    def ddp_train(self, process_idx, model):
        """
        Entry point for ddp

        Args:
            process_idx: the SLURM local task index (SLURM_LOCALID) of this process
            model: the LightningModule to train

        Returns:
            the results of train_or_test()
        """
        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # toggle prog bar
        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        self.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0 let everyone know training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend} (SLURM)')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # clean up memory
        torch.cuda.empty_cache()

        return results
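For reference, a hedged sketch of how this backend is meant to be driven: SLURM launches one task per process (the sbatch/srun job is not shown here), each task runs the same script, and the backend reads SLURM_LOCALID in setup() rather than spawning children. LitModel is a placeholder, and the node/process counts are assumptions that would have to match the SLURM job:

from pytorch_lightning import Trainer

def main():
    model = LitModel()  # placeholder: your LightningModule
    trainer = Trainer(
        distributed_backend='ddp_cpu',
        num_nodes=2,      # assumed: matches the SLURM node count
        num_processes=8,  # assumed: matches tasks-per-node, so SLURM manages the tasks
    )
    trainer.fit(model)

if __name__ == '__main__':
    main()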
173 changes: 173 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py
@@ -0,0 +1,173 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist

from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed


try:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig
except ImportError:
    HYDRA_AVAILABLE = False
else:
    HYDRA_AVAILABLE = True


# -------------------------------------------
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# TEMP CLASS WHILE WE DECOUPLE TE FROM DDP
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
# -------------------------------------------
class DDPCPUTorchElasticBackend(Accelerator):

    def __init__(self, trainer, cluster_environment=None):
        super().__init__(trainer, cluster_environment)
        self.task_idx = None
        self._has_spawned_children = False
        self.dist = LightningDistributed()

    def setup(self, model):
        self.trainer.model = model
        self.task_idx = int(os.environ['LOCAL_RANK'])

    def train(self):
        model = self.trainer.model
        self.ddp_train(process_idx=self.task_idx, model=model)

    def set_world_ranks(self, process_idx):
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def model_to_device(self, model, process_idx):
        model.cpu()

    def get_device_ids(self):
        device_ids = None
        return device_ids

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.reduce_op.SUM)
        dist.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop

    def broadcast(self, obj, src=0):
        return self.dist.broadcast(obj)

    def ddp_train(self, process_idx, model):
        """
        Entry point for ddp

        Args:
            process_idx: the TorchElastic local rank (LOCAL_RANK) of this process
            model: the LightningModule to train

        Returns:
            the results of train_or_test()
        """
        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # toggle prog bar
        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        self.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on global rank 0 let everyone know training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend} (TORCH_ELASTIC)')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # 16-bit
        model = self.trainer.precision_connector.connect(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # clean up memory
        torch.cuda.empty_cache()

        return results
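The TorchElastic variant is identical except for where the task index comes from: LOCAL_RANK, exported by the elastic launcher, instead of SLURM_LOCALID. A worked sketch of the rank arithmetic used in set_world_ranks() above (the env values and process/node counts are assumed):

import os

os.environ.setdefault('LOCAL_RANK', '3')  # assumed: set by the elastic launcher per worker
os.environ.setdefault('NODE_RANK', '1')   # assumed: second node

num_processes = 4                         # assumed: Trainer(num_processes=4)
num_nodes = 2                             # assumed: Trainer(num_nodes=2)

local_rank = int(os.environ['LOCAL_RANK'])
node_rank = int(os.environ['NODE_RANK'])

global_rank = node_rank * num_processes + local_rank  # 1 * 4 + 3 = 7
world_size = num_nodes * num_processes                # 2 * 4 = 8
print(global_rank, world_size)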