From 89f284d6fbafcf6aadac7abf1af08ee3fba39865 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 23 Mar 2021 12:06:24 -0700 Subject: [PATCH 01/49] Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 4 +++- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 3 ++- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ea1efd6e15873..1383964cbc789 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9586344d8c0d9..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,11 +26,12 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group @@ -51,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def 
setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 536c1323b0e6715fb5919196ea48b0fcddddcd66 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 01:17:20 -0700 Subject: [PATCH 02/49] checkpoint consolidation --- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: 
Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. 
+ """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ 
b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:37:52 
-0700 Subject: [PATCH 03/49] Update ddp_spawn.py --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From bf70e431b3ce4893de804e0f3b5d59e79346d6d7 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:41:33 -0700 Subject: [PATCH 04/49] Update test_metric_result_integration.py --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From ea749068785bbad689a12066544893b1605f20c5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:42:16 -0700 Subject: [PATCH 05/49] Update test_results.py --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From a9aae99f6ed6e9388ecf1d8a7bd79966176a65af Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:04 -0700 Subject: [PATCH 06/49] Update utils.py --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - 
os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): From 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:43:43 -0700 Subject: [PATCH 07/49] Update utils.py --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 0d23d75bc91e4e0b7805712e394cb093fac22841 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:44:18 -0700 Subject: [PATCH 08/49] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From ca6f98ba8ff835368ae3ef91e435e4d4f458c45b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:51:55 -0700 Subject: [PATCH 09/49] Update test_all_gather_grad.py --- tests/utilities/test_all_gather_grad.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 9d4a2b891d2a4b37e21529a444bda1883d1b5ed1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 01:57:31 -0700 Subject: [PATCH 10/49] Update test_results.py --- tests/core/test_results.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..334062ae994a2 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,6 +30,7 @@ def _setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From 7635b4f47bcef43b9bbe677ad96a3bad135246a5 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:10 -0700 Subject: [PATCH 11/49] Revert "Update test_results.py" This reverts commit 
9d4a2b891d2a4b37e21529a444bda1883d1b5ed1. --- tests/core/test_results.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 334062ae994a2..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,7 +30,6 @@ def _setup_ddp(rank, worldsize): import os - os.environ["MASTER_ADDR"] = "localhost" # initialize the process group From d64f90cbc748de193a02237acd6ac686750b82d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:20 -0700 Subject: [PATCH 12/49] Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" This reverts commit c5053da789f9d04d2c967a65adf4fb026dc134b8, reversing changes made to 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 259f9f4c09871..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,6 +44,7 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From dcdcd29731061c919b15ab0b56669259817a81c4 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:36 -0700 Subject: [PATCH 13/49] Revert "Update test_all_gather_grad.py" This reverts commit 0d23d75bc91e4e0b7805712e394cb093fac22841. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 8651d54d79396eaaba16d7eb1e769a1e91d5702e Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:40 -0700 Subject: [PATCH 14/49] Revert "Update utils.py" This reverts commit 70fe5da9c66ceff2fcf4be5b9efdd23a9af8389c. 
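For what it is worth, this series keeps toggling MASTER_PORT between the hard-coded values 8088 and 29501, which only moves the port-collision problem around. One alternative, not used by any of these patches, is to ask the OS for a currently free port and hand that to the test, roughly:

    import socket

    def find_free_port() -> int:
        # Bind to port 0 so the OS picks an unused TCP port; the socket is closed
        # right away, so a small race with other processes remains possible.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("localhost", 0))
            return sock.getsockname()[1]

    # os.environ["MASTER_PORT"] = str(find_free_port())

find_free_port is a hypothetical helper, not something these tests define; collisions become unlikely but not impossible because the port is released before init_process_group rebinds it.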
--- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 15f4b9e59bec52b07dddb55eeda4d9a68b8bd6d2 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:45 -0700 Subject: [PATCH 15/49] Revert "Update utils.py" This reverts commit a9aae99f6ed6e9388ecf1d8a7bd79966176a65af. --- tests/helpers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): From 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:48 -0700 Subject: [PATCH 16/49] Revert "Update test_results.py" This reverts commit ea749068785bbad689a12066544893b1605f20c5. --- tests/core/test_results.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From 6c095b2370a2afe9d24918a5798ce1ebffed7e0d Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:52 -0700 Subject: [PATCH 17/49] Revert "Update test_metric_result_integration.py" This reverts commit bf70e431b3ce4893de804e0f3b5d59e79346d6d7. 
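Context for the re-added environment variable: mkl-service aborts spawned workers when NumPy/MKL and a libgomp-linked library meet in the same process (the error says MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 and itself suggests setting MKL_SERVICE_FORCE_INTEL). Assuming that is the failure being worked around here, the variable has to be exported in the parent process before mp.spawn forks the workers, which is what the test does; a minimal standalone sketch of the pattern, with a placeholder worker instead of the test's _ddp_test_fn, is:

    import os
    import torch.multiprocessing as mp

    def _worker(rank, world_size):
        # Placeholder for the DDP body; the real test uses _ddp_test_fn here.
        pass

    if __name__ == "__main__":
        # Force the Intel threading layer before any spawned worker imports numpy/MKL.
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
        mp.spawn(_worker, args=(2, ), nprocs=2)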
--- tests/core/test_metric_result_integration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From 8222dc98ead37d961a52b7366070aa10f66d92d1 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:06:55 -0700 Subject: [PATCH 18/49] Revert "Update ddp_spawn.py" This reverts commit f17210183b84f90c9a62d1ff9b3e05e1fbe5f33b. --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..941025b36c0ac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer +import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:17:01 -0700 Subject: [PATCH 19/49] Revert "checkpoint consolidation" This reverts commit 536c1323b0e6715fb5919196ea48b0fcddddcd66. 
--- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint 
only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the 
correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 7a369f47e1a94d701fce48c994cc3f2da266dad0 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:19:37 -0700 Subject: [PATCH 20/49] Revert "Revert "checkpoint consolidation"" This reverts commit 3a9fde915ad4c69620a6ccc411f5890cb38ba5ac. 
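With this change re-applied, Callback gains an on_train_epoch_final_end hook that the trainer invokes once at the very end of a training epoch when no validation run follows it; ModelCheckpoint and EarlyStopping use it to do their epoch-end work there instead of being driven from the training loop. A user-defined callback could opt in the same way (hypothetical example; the hook only exists with this patch applied and is not part of upstream Lightning):

    from pytorch_lightning.callbacks import Callback

    class EpochEndAudit(Callback):
        # Relies on the hook added by this patch; on stock Lightning this
        # method would simply never be called.
        def on_train_epoch_final_end(self, trainer, pl_module):
            # Invoked by the trainer only when the epoch finishes without a
            # validation loop, mirroring where checkpointing/early stopping now run.
            print(f"epoch {trainer.current_epoch} done at global step {trainer.global_step}")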
--- pytorch_lightning/callbacks/base.py | 4 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++ .../callbacks/lambda_function.py | 3 ++ .../callbacks/model_checkpoint.py | 31 ++++++++++++++++ pytorch_lightning/trainer/callback_hook.py | 7 ++++ .../callback_hook_validator.py | 5 +++ pytorch_lightning/trainer/training_loop.py | 35 ++----------------- tests/checkpointing/test_model_checkpoint.py | 35 +++++++++++++++---- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 + 10 files changed, 99 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..ffb26f38ca821 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass + def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: + """Called when at the very end of train epoch.""" + pass + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4448de8e4834b..0de8ff6f0b505 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if ( + trainer.state != TrainerState.FITTING or trainer.sanity_checking + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we run early stopping + # at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self._run_early_stopping_check(trainer) + def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 58324e363cd37..2a56e1c8ac6e0 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,6 +53,7 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, + on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -155,3 +156,5 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad + if on_train_epoch_final_end is not None: + self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..9436720e3819b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) + def on_train_epoch_final_end(self, trainer, pl_module): + """ + at the end of each training epoch, checkpoint 
only when validation is skipped or disabled + """ + print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) + if ( + self._should_skip_saving_checkpoint(trainer) + or not trainer.checkpoint_connector.has_trained + ): + return + # if validation is disabled or should skip, we checkpoint at end of the training epoch + if ( + trainer.disable_validation + or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) + ): + self.save_checkpoint(trainer) + + def on_train_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the trianing + """ + trainer.global_step -= 1 + if ( + not self._should_skip_saving_checkpoint(trainer) + and trainer.checkpoint_connector.has_trained + ): + if self.save_last and self.verbose: + rank_zero_info("Saving latest checkpoint...") + self.save_checkpoint(trainer) + trainer.global_step += 1 + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..c53c21ad04bc3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) + def on_train_epoch_final_end(self) -> None: + """ + Called when at the very end of train epoch. + """ + for callback in self.callbacks: + callback.on_train_epoch_final_end(self, self.lightning_module) + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 534dad5199e9b..e7884124df314 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,6 +100,11 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod + def _on_train_epoch_final_end_log(): + """Called when at the very end of train epoch.""" + return {"on_step": [False], "on_epoch": [False, True]} + @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c3ba34ca66d2d..1d498a0a9ff6c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,12 +121,6 @@ def on_train_end(self): return self._teardown_already_run = True - # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates - # when a checkpoint was saved at the last step - self.trainer.global_step -= 1 - self.check_checkpoint_callback(should_update=True, is_last=True) - self.trainer.global_step += 1 - # hook self.trainer.call_hook("on_train_end") @@ -145,28 +139,6 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None - def check_checkpoint_callback(self, should_update, is_last=False): - # TODO bake this logic into the ModelCheckpoint callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = self.trainer.checkpoint_callbacks - - if is_last and any(cb.save_last and cb.verbose for cb in callbacks): - rank_zero_info("Saving latest checkpoint...") - - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - - def check_early_stopping_callback(self, should_update): - # TODO bake this logic into the EarlyStopping callback - if should_update and self.trainer.checkpoint_connector.has_trained: - callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] - model = self.trainer.lightning_module - - for cb in callbacks: - cb.on_validation_end(self.trainer, model) - def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -562,15 +534,14 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True + if should_train_only: + self.trainer.call_hook('on_train_epoch_final_end') + # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 75f25b90fa45f..e0c295a843a21 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] + if period > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - expected = 
[f'epoch={e}.ckpt' for e in range(epochs) - if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + # check that the correct ckpts were created + final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) + expected = ( + [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] + if every_n_val_epochs > 0 + else [] + ) + expected.append(final_epoch_ckpt) assert set(os.listdir(tmpdir)) == set(expected) @@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, + val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + if verbose and save_last and not should_validate: + # no validation, hence checkpoint triggered at the end of each training epoch + assert caplog.messages.count('Saving latest checkpoint...') == False + else: + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..b2727177bcacd 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', + 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From b4a0b9e9e1e0a08e50979facc2f0fc74187de2ee Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 02:32:43 -0700 Subject: [PATCH 21/49] Revert "Revert "Revert "checkpoint consolidation""" This reverts commit 7a369f47e1a94d701fce48c994cc3f2da266dad0. 
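This puts end-of-training checkpointing back in the training loop (check_checkpoint_callback with the temporary global_step decrement) rather than in ModelCheckpoint.on_train_end. The code path it restores is the one hit by a run with no validation and a verbose, save_last checkpoint callback; a minimal sketch of such a run, assuming a BoringModel-style module and an illustrative dirpath, looks like:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from tests.helpers import BoringModel

    ckpt = ModelCheckpoint(dirpath="checkpoints/", save_last=True, verbose=True)
    trainer = Trainer(
        max_epochs=1,
        limit_val_batches=0,  # no validation, so the train loop drives checkpointing
        callbacks=[ckpt],
    )
    trainer.fit(BoringModel())
    # "Saving latest checkpoint..." is expected to be logged once at the end of fit;
    # global_step is decremented and restored around that save to avoid a duplicate.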
--- pytorch_lightning/callbacks/base.py | 4 --- pytorch_lightning/callbacks/early_stopping.py | 15 -------- .../callbacks/lambda_function.py | 3 -- .../callbacks/model_checkpoint.py | 31 ---------------- pytorch_lightning/trainer/callback_hook.py | 7 ---- .../callback_hook_validator.py | 5 --- pytorch_lightning/trainer/training_loop.py | 35 +++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 35 ++++--------------- tests/helpers/utils.py | 2 +- .../trainer/logging_/test_logger_connector.py | 1 - 10 files changed, 39 insertions(+), 99 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index ffb26f38ca821..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the epoch ends.""" pass - def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None: - """Called when at the very end of train epoch.""" - pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 0de8ff6f0b505..4448de8e4834b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.state != TrainerState.FITTING or trainer.sanity_checking - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we run early stopping - # at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): """ Checks whether the early stopping condition is met diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 2a56e1c8ac6e0..58324e363cd37 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -53,7 +53,6 @@ def __init__( on_train_batch_end: Optional[Callable] = None, on_train_epoch_start: Optional[Callable] = None, on_train_epoch_end: Optional[Callable] = None, - on_train_epoch_final_end: Optional[Callable] = None, on_validation_epoch_start: Optional[Callable] = None, on_validation_epoch_end: Optional[Callable] = None, on_test_epoch_start: Optional[Callable] = None, @@ -156,5 +155,3 @@ def __init__( self.on_after_backward = on_after_backward if on_before_zero_grad is not None: self.on_before_zero_grad = on_before_zero_grad - if on_train_epoch_final_end is not None: - self.on_train_epoch_final_end = on_train_epoch_final_end diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9436720e3819b..2a0c108ba7603 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None: return self.save_checkpoint(trainer) - def on_train_epoch_final_end(self, trainer, pl_module): - """ - at the end of each training epoch, checkpoint 
only when validation is skipped or disabled - """ - print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step)) - if ( - self._should_skip_saving_checkpoint(trainer) - or not trainer.checkpoint_connector.has_trained - ): - return - # if validation is disabled or should skip, we checkpoint at end of the training epoch - if ( - trainer.disable_validation - or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches) - ): - self.save_checkpoint(trainer) - - def on_train_end(self, trainer, *args, **kwargs) -> None: - """ - checkpoints can be saved at the end of the trianing - """ - trainer.global_step -= 1 - if ( - not self._should_skip_saving_checkpoint(trainer) - and trainer.checkpoint_connector.has_trained - ): - if self.save_last and self.verbose: - rank_zero_info("Saving latest checkpoint...") - self.save_checkpoint(trainer) - trainer.global_step += 1 - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c53c21ad04bc3..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]): for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) - def on_train_epoch_final_end(self) -> None: - """ - Called when at the very end of train epoch. - """ - for callback in self.callbacks: - callback.on_train_epoch_final_end(self, self.lightning_module) - def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index e7884124df314..534dad5199e9b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -100,11 +100,6 @@ def _on_train_epoch_end_log(): """Called when the epoch ends.""" return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod - def _on_train_epoch_final_end_log(): - """Called when at the very end of train epoch.""" - return {"on_step": [False], "on_epoch": [False, True]} - @staticmethod def _on_validation_epoch_start_log(): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1d498a0a9ff6c..c3ba34ca66d2d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -121,6 +121,12 @@ def on_train_end(self): return self._teardown_already_run = True + # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 + # hook self.trainer.call_hook("on_train_end") @@ -139,6 +145,28 @@ def on_train_end(self): # reset bookkeeping self.trainer._running_stage = None + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks + + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") + + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module + + for cb in callbacks: + cb.on_validation_end(self.trainer, model) + def on_train_epoch_start(self, epoch): # update training progress in trainer @@ -534,14 +562,15 @@ def run_training_epoch(self): if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.validating = True self.trainer.run_evaluation(on_epoch=True) self.trainer.training = True - if should_train_only: - self.trainer.call_hook('on_train_epoch_final_end') - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e0c295a843a21..75f25b90fa45f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int): trainer.fit(model) # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs] - if period > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): trainer.fit(model) # check that the correct ckpts were created - # check that the correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc trainer.fit(model) # check that the correct ckpts were created - # check that the 
correct ckpts were created - final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1) - expected = ( - [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs] - if every_n_val_epochs > 0 - else [] - ) - expected.append(final_epoch_ckpt) + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning( default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, - val_check_interval=0.1, ) with caplog.at_level(logging.INFO): trainer.fit(model) - if verbose and save_last and not should_validate: - # no validation, hence checkpoint triggered at the end of each training epoch - assert caplog.messages.count('Saving latest checkpoint...') == False - else: - assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b2727177bcacd..3db0a8eaa065b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir): 'on_train_batch_start', 'on_train_end', 'on_train_epoch_end', - 'on_train_epoch_final_end', 'on_train_epoch_start', 'on_train_start', 'on_validation_batch_end', From 0ce7e056ac47436bc727f91f8eed335fc736696c Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:44 -0700 Subject: [PATCH 22/49] Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc98ead37d961a52b7366070aa10f66d92d1. 
--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 941025b36c0ac..87d7fa5faecac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,7 +21,6 @@ import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -import numpy from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -79,7 +78,6 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") From fe9736d94bfcac3e084eec2d63e62351d3618175 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:49 -0700 Subject: [PATCH 23/49] Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2370a2afe9d24918a5798ce1ebffed7e0d. --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) From c314ef6d30373c2c94fdeceef6ee7b9d961a48c9 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:31:56 -0700 Subject: [PATCH 24/49] Revert "Revert "Update test_results.py"" This reverts commit 250d0aaaa2e6c6a6a3407bc6c8b83c0fe2479c0b. --- tests/core/test_results.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 From c3feda03d7fbb25dcf1917e718209f98d0503327 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:05 -0700 Subject: [PATCH 25/49] Revert "Revert "Update utils.py"" This reverts commit 8651d54d79396eaaba16d7eb1e769a1e91d5702e. 
--- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From c759477a0a9462f812a880a8cee7c09b3f432520 Mon Sep 17 00:00:00 2001 From: shuyingsunshine21 <80445420+shuyingsunshine21@users.noreply.github.com> Date: Wed, 24 Mar 2021 10:32:13 -0700 Subject: [PATCH 26/49] Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29731061c919b15ab0b56669259817a81c4. --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 4e67db2a1fed55e6d6e3fa09766aaa55c7995ca1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 24 Mar 2021 10:57:58 -0700 Subject: [PATCH 27/49] modify distributed environment to make test pass --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 ++- tests/core/test_metric_result_integration.py | 3 +++ tests/core/test_results.py | 3 +++ tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 2 +- 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 87d7fa5faecac..0b4b7680076a3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything +import numpy log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ def distributed_sampler_kwargs(self): def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 
2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..74c4a0c212564 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf +import os +import numpy def _setup_ddp(rank, worldsize): @@ -50,6 +52,7 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index f5c1726a423bb..493d32d3fe454 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_PORT'] = "29501" def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624936..4aac65257a504 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..a9f38a9e1d88c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "8088" + os.environ["MASTER_PORT"] = "29501" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 1e41d5b78f1ca097dde011791be366be25648155 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 17:02:19 -0700 Subject: [PATCH 28/49] add DDP communication hook --- .../plugins/training_type/ddp.py | 24 ++++++- .../plugins/training_type/ddp2.py | 4 ++ .../training_type/ddp_comm_hook_util.py | 68 +++++++++++++++++++ .../plugins/training_type/ddp_spawn.py | 22 +++++- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/imports.py | 1 + tests/plugins/test_custom_plugin.py | 56 ++++++++++++++- 7 files changed, 172 insertions(+), 5 deletions(-) create mode 100644 pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 58e26e7db32d8..8ea5bb8015f53 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,10 +29,15 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn +from 
pytorch_lightning.utilities import ( + _HYDRA_AVAILABLE, + _TORCH_GREATER_EQUAL_1_7, + rank_zero_warn, +) from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig @@ -58,6 +63,9 @@ def __init__( num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) @@ -70,6 +78,9 @@ def __init__( self.task_idx = None self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices + self.ddp_comm_state = ddp_comm_state + self.ddp_comm_hook = ddp_comm_hook + self.ddp_comm_wrapper = ddp_comm_wrapper @property def root_device(self): @@ -80,6 +91,10 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs + @property + def is_single_process_single_device(self): + return True + def setup_environment(self): # start the other scripts if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": @@ -225,6 +240,13 @@ def configure_ddp(self): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) + register_ddp_comm_hook( + ddp_comm_state=self.ddp_comm_state, + ddp_comm_hook=self.ddp_comm_hook, + ddp_comm_wrapper=self.ddp_comm_wrapper, + model=self._model, + is_single_process_single_device=self.is_single_process_single_device, + ) def determine_ddp_device_ids(self): if self.root_device.type == "cpu": diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index a94bb5459bb1e..9690c468e86c2 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -59,6 +59,10 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank) return distributed_sampler_kwargs + @property + def is_single_process_single_device(self): + return False + def set_world_ranks(self): self.local_rank = self.task_idx self.node_rank = self.cluster_environment.node_rank() diff --git a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py new file mode 100644 index 0000000000000..6698c2e94bef1 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py @@ -0,0 +1,68 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +from typing import Optional + +from pytorch_lightning.utilities import ( + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_9, + rank_zero_warn, + rank_zero_info, +) +from torch.distributed.algorithms.ddp_comm_hooks import ( + DDPCommHookType, + register_ddp_comm_hook, +) +from torch.nn.parallel.distributed import DistributedDataParallel + + +def register_ddp_comm_hook( + ddp_comm_state: Optional[object], + ddp_comm_hook: Optional[callable], + ddp_comm_wrapper: Optional[callable], + model: DistributedDataParallel, + is_single_process_single_device: bool, +): + # register DDP comm hook: https://pytorch.org/docs/master/ddp_comm_hooks.html + if ddp_comm_hook is None: + rank_zero_info("No DDP comm hook is provided, skipping.") + return + if not _TORCH_GREATER_EQUAL_1_7: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, please use PyTorch version at least 1.7.0." + ) + return + if not is_single_process_single_device: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, must be single process single device, see " + "https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L1035" + ) + if ddp_comm_wrapper is not None: + if not _TORCH_GREATER_EQUAL_1_9: + rank_zero_warn( + "Not applying DDP comm wrapper. " + "To use communication wrapper, please use PyTorch version at least 1.9.0." + ) + else: + rank_zero_info( + "DDP comm wrapper is provided, apply ddp_comm_wrapper(ddp_comm_hook)." + ) + ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) + + rank_zero_info("Registering DDP comm hook.") + model.register_comm_hook( + state=ddp_comm_state, + hook=ddp_comm_hook, + ) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0b4b7680076a3..5f2376f1fb84c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -import numpy +from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -48,6 +48,9 @@ def __init__( num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, sync_batchnorm: bool = False, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_wrapper_hook: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) @@ -58,6 +61,9 @@ def __init__( self.num_processes = len(parallel_devices) self.node_rank = 0 self.mp_queue = None + self.ddp_comm_state = ddp_comm_state + self.ddp_comm_hook = ddp_comm_hook + self.ddp_wrapper_hook = ddp_wrapper_hook def __getstate__(self): """ Makes this plugin pickleable without destroying the queue in the current process. 
""" @@ -77,9 +83,13 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs + @property + def is_single_process_single_device(self): + return True + def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" + # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() @@ -190,6 +200,14 @@ def configure_ddp(self): **self._ddp_kwargs, ) + register_ddp_comm_hook( + ddp_comm_state=self.ddp_comm_state, + ddp_comm_hook=self.ddp_comm_hook, + ddp_wrapper_hook=self.ddp_wrapper_hook, + model=self._model, + is_single_process_single_device=self.is_single_process_single_device, + ) + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 3e2ee3e51efe1..95e64e1bca33f 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" -import numpy from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 @@ -39,6 +38,7 @@ _RPC_AVAILABLE, _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_9, _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index baeac9be57218..5ffc53ee927cb 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -69,6 +69,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") +_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0") _KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 872b49ef48635..9827dc22702e0 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -15,10 +15,13 @@ from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel from tests.helpers.runif import RunIf +from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as default, + powerSGD_hook as powerSGD, +) class CustomParallelPlugin(DDPPlugin): - def __init__(self, **kwargs): super().__init__(**kwargs) # Set to None so it will be overwritten by the accelerator connector. 
@@ -39,3 +42,54 @@ def test_sync_batchnorm_set(tmpdir): ) trainer.fit(model) assert plugin.sync_batchnorm is True + + +@RunIf(skip_windows=True, min_torch="1.7.0") +def test_ddp_fp16_compress_comm_hook(tmpdir): + """Test for DDP FP16 compress hook.""" + model = BoringModel() + plugin = DDPPlugin( + ddp_comm_hook=default.fp16_compress_hook, + ) + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + + +@RunIf(skip_windows=True, min_torch="1.7.0") +def test_ddp_sgd_graident_comm_hook(tmpdir): + """Test for DDP SGD hook.""" + model = BoringModel() + plugin = DDPPlugin( + ddp_comm_state=powerSGD.PowerSGDState(1), + ddp_comm_hook=powerSGD.powerSGD_hook, + ) + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + + +@RunIf(skip_windows=True, min_torch="1.9.0") +def test_ddp_fp16_compress_wrapper_comm_hook(tmpdir): + """Test for DDP fp16 compress wrapper for SGD hook.""" + model = BoringModel() + plugin = DDPPlugin( + ddp_comm_state=powerSGD.PowerSGDState(1), + ddp_comm_hook=powerSGD.powerSGD_hook, + ddp_comm_wrapper=default.fp16_compress_wrapper, + ) + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) From 6833b8705091ead69c84e8d2878578ea713fbe99 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 17:28:59 -0700 Subject: [PATCH 29/49] remove test related setting --- pytorch_lightning/plugins/training_type/ddp.py | 6 +----- pytorch_lightning/utilities/__init__.py | 1 + tests/core/test_metric_result_integration.py | 3 --- tests/core/test_results.py | 3 --- tests/helpers/utils.py | 2 +- tests/metrics/utils.py | 2 +- tests/utilities/test_all_gather_grad.py | 2 +- 7 files changed, 5 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 8ea5bb8015f53..9fff56ad154dd 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,11 +29,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities import ( - _HYDRA_AVAILABLE, - _TORCH_GREATER_EQUAL_1_7, - rank_zero_warn, -) +from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 95e64e1bca33f..e69a9947d9f8b 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""General utilities""" +import numpy from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ffbe508816403..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,8 +16,6 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric -import numpy -import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -98,7 +96,6 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 74c4a0c212564..f25ab0c40a6ea 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -26,8 +26,6 @@ from pytorch_lightning.trainer.states import TrainerState from tests.helpers import BoringDataModule, BoringModel from tests.helpers.runif import RunIf -import os -import numpy def _setup_ddp(rank, worldsize): @@ -52,7 +50,6 @@ def _ddp_test_fn(rank, worldsize, result_cls: Result): def test_result_reduce_ddp(): """Make sure result logging works with DDP""" tutils.reset_seed() - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" tutils.set_random_master_port() worldsize = 2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 493d32d3fe454..f5c1726a423bb 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -76,7 +76,7 @@ def reset_seed(seed=0): def set_random_master_port(): reset_seed() port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = "29501" + os.environ['MASTER_PORT'] = str(port) def init_checkpoint_callback(logger): diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4aac65257a504..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -26,7 +26,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a9f38a9e1d88c..f1860b10326e9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -13,7 +13,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" + os.environ["MASTER_PORT"] = "8088" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From f856d31636e166ed24bb337520cdcaa2ec05a7a3 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 17:34:26 -0700 Subject: [PATCH 30/49] remove more test related setting --- tests/core/test_results.py | 1 + tests/utilities/test_all_gather_grad.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_results.py b/tests/core/test_results.py index f25ab0c40a6ea..9586344d8c0d9 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -30,6 +30,7 @@ def 
_setup_ddp(rank, worldsize): import os + os.environ["MASTER_ADDR"] = "localhost" # initialize the process group diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f1860b10326e9..259f9f4c09871 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -44,7 +44,6 @@ def _test_all_gather_ddp(rank, world_size): @RunIf(skip_windows=True) def test_all_gather_ddp(): world_size = 3 - os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) From 14a0a1ba81c7993f68f5fbfb1983a64e6fa14168 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 17:53:43 -0700 Subject: [PATCH 31/49] fix ddp comm hook util import issue --- pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py index 6698c2e94bef1..2a6a45f00fd84 100644 --- a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py +++ b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py @@ -19,10 +19,6 @@ rank_zero_warn, rank_zero_info, ) -from torch.distributed.algorithms.ddp_comm_hooks import ( - DDPCommHookType, - register_ddp_comm_hook, -) from torch.nn.parallel.distributed import DistributedDataParallel From 8998469a8560d61e73334ce03433ff24311d1bff Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 19:05:29 -0700 Subject: [PATCH 32/49] comments --- .../plugins/training_type/ddp.py | 32 +++++++++++------- .../training_type/ddp_comm_hook_util.py | 9 +---- .../plugins/training_type/ddp_spawn.py | 33 ++++++++++++------- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 9fff56ad154dd..78a9fc5df3732 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -33,11 +33,12 @@ from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything -from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path +if _TORCH_GREATER_EQUAL_1_7: + from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -74,9 +75,9 @@ def __init__( self.task_idx = None self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices - self.ddp_comm_state = ddp_comm_state - self.ddp_comm_hook = ddp_comm_hook - self.ddp_comm_wrapper = ddp_comm_wrapper + self._ddp_comm_state = ddp_comm_state + self._ddp_comm_hook = ddp_comm_hook + self._ddp_comm_wrapper = ddp_comm_wrapper @property def root_device(self): @@ -229,6 +230,21 @@ def pre_configure_ddp(self): ) self._ddp_kwargs["find_unused_parameters"] = True + def register_model_hook(self) -> None: + if not _TORCH_GREATER_EQUAL_1_7: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, please use PyTorch version at least 1.7.0." 
+ ) + return + register_ddp_comm_hook( + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + model=self._model, + is_single_process_single_device=self.is_single_process_single_device, + ) + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( @@ -236,13 +252,7 @@ def configure_ddp(self): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) - register_ddp_comm_hook( - ddp_comm_state=self.ddp_comm_state, - ddp_comm_hook=self.ddp_comm_hook, - ddp_comm_wrapper=self.ddp_comm_wrapper, - model=self._model, - is_single_process_single_device=self.is_single_process_single_device, - ) + self.register_model_hook() def determine_ddp_device_ids(self): if self.root_device.type == "cpu": diff --git a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py index 2a6a45f00fd84..5acdedb992e7f 100644 --- a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py +++ b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py @@ -14,7 +14,6 @@ from typing import Optional from pytorch_lightning.utilities import ( - _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_9, rank_zero_warn, rank_zero_info, @@ -28,17 +27,11 @@ def register_ddp_comm_hook( ddp_comm_wrapper: Optional[callable], model: DistributedDataParallel, is_single_process_single_device: bool, -): +) -> None: # register DDP comm hook: https://pytorch.org/docs/master/ddp_comm_hooks.html if ddp_comm_hook is None: rank_zero_info("No DDP comm hook is provided, skipping.") return - if not _TORCH_GREATER_EQUAL_1_7: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, please use PyTorch version at least 1.7.0." - ) - return if not is_single_process_single_device: rank_zero_warn( "Not registering DDP comm hook. " diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5f2376f1fb84c..daf9c86ddb6a0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,8 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook +if _TORCH_GREATER_EQUAL_1_7: + from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -61,9 +62,9 @@ def __init__( self.num_processes = len(parallel_devices) self.node_rank = 0 self.mp_queue = None - self.ddp_comm_state = ddp_comm_state - self.ddp_comm_hook = ddp_comm_hook - self.ddp_wrapper_hook = ddp_wrapper_hook + self._ddp_comm_state = ddp_comm_state + self._ddp_comm_hook = ddp_comm_hook + self._ddp_wrapper_hook = ddp_wrapper_hook def __getstate__(self): """ Makes this plugin pickleable without destroying the queue in the current process. """ @@ -192,6 +193,21 @@ def pre_configure_ddp(self): ) self._ddp_kwargs["find_unused_parameters"] = True + def register_model_hook(self) -> None: + if not _TORCH_GREATER_EQUAL_1_7: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, please use PyTorch version at least 1.7.0." 
+ ) + return + register_ddp_comm_hook( + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + model=self._model, + is_single_process_single_device=self.is_single_process_single_device, + ) + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( @@ -199,14 +215,7 @@ def configure_ddp(self): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) - - register_ddp_comm_hook( - ddp_comm_state=self.ddp_comm_state, - ddp_comm_hook=self.ddp_comm_hook, - ddp_wrapper_hook=self.ddp_wrapper_hook, - model=self._model, - is_single_process_single_device=self.is_single_process_single_device, - ) + self.register_model_hook() def init_ddp_connection(self, global_rank: int, world_size: int) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function From a17947b5fc116ff0b901b1384e294ff4df65fed8 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 19:58:13 -0700 Subject: [PATCH 33/49] one more fix for test_custom_plugin --- tests/plugins/test_custom_plugin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 9827dc22702e0..d572a54c665d1 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -15,10 +15,12 @@ from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel from tests.helpers.runif import RunIf -from torch.distributed.algorithms.ddp_comm_hooks import ( - default_hooks as default, - powerSGD_hook as powerSGD, -) +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +if _TORCH_GREATER_EQUAL_1_7: + from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as default, + powerSGD_hook as powerSGD, + ) class CustomParallelPlugin(DDPPlugin): From 91a945a496dbcea6c206fc9614beacc45dff6120 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 21:57:18 -0700 Subject: [PATCH 34/49] fix ddp spwan --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index daf9c86ddb6a0..9307b6423b0df 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -51,7 +51,7 @@ def __init__( sync_batchnorm: bool = False, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, - ddp_wrapper_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ): super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) @@ -64,7 +64,7 @@ def __init__( self.mp_queue = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook - self._ddp_wrapper_hook = ddp_wrapper_hook + self._ddp_comm_wrapper = ddp_comm_wrapper def __getstate__(self): """ Makes this plugin pickleable without destroying the queue in the current process. 
""" From 78c6925dc6e9c3dee70049fae8c268618e8361af Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 29 Mar 2021 23:16:04 -0700 Subject: [PATCH 35/49] fix sgd --- tests/plugins/test_custom_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index d572a54c665d1..a77df6d6471b2 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -67,7 +67,7 @@ def test_ddp_sgd_graident_comm_hook(tmpdir): """Test for DDP SGD hook.""" model = BoringModel() plugin = DDPPlugin( - ddp_comm_state=powerSGD.PowerSGDState(1), + ddp_comm_state=powerSGD.PowerSGDState(None), ddp_comm_hook=powerSGD.powerSGD_hook, ) trainer = Trainer( @@ -84,7 +84,7 @@ def test_ddp_fp16_compress_wrapper_comm_hook(tmpdir): """Test for DDP fp16 compress wrapper for SGD hook.""" model = BoringModel() plugin = DDPPlugin( - ddp_comm_state=powerSGD.PowerSGDState(1), + ddp_comm_state=powerSGD.PowerSGDState(None), ddp_comm_hook=powerSGD.powerSGD_hook, ddp_comm_wrapper=default.fp16_compress_wrapper, ) From 443f223a0431bd4266fb22c7bf8664272d52cd7d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 30 Mar 2021 14:42:47 -0700 Subject: [PATCH 36/49] address comments and add tests --- .../plugins/training_type/ddp.py | 27 ++-- .../plugins/training_type/ddp2.py | 2 +- .../training_type/ddp_comm_hook_util.py | 57 --------- .../plugins/training_type/ddp_spawn.py | 30 ++--- pytorch_lightning/utilities/distributed.py | 115 ++++++++++++++++++ tests/plugins/test_custom_plugin.py | 57 --------- .../plugins/test_ddp_plugin_with_comm_hook.py | 70 +++++++++++ 7 files changed, 209 insertions(+), 149 deletions(-) delete mode 100644 pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py create mode 100644 tests/plugins/test_ddp_plugin_with_comm_hook.py diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 78a9fc5df3732..64b7748e12339 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -38,7 +38,7 @@ from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path if _TORCH_GREATER_EQUAL_1_7: - from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook + from pytorch_lightning.utilities.distributed import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_single_process_single_device(self): + def _is_single_process_single_device(self) -> bool: return True def setup_environment(self): @@ -230,20 +230,15 @@ def pre_configure_ddp(self): ) self._ddp_kwargs["find_unused_parameters"] = True - def register_model_hook(self) -> None: - if not _TORCH_GREATER_EQUAL_1_7: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, please use PyTorch version at least 1.7.0." 
+ def _register_ddp_hooks(self) -> None: + if _TORCH_GREATER_EQUAL_1_7: + register_ddp_comm_hook( + model=self._model, + is_single_process_single_device=self._is_single_process_single_device, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, ) - return - register_ddp_comm_hook( - ddp_comm_state=self._ddp_comm_state, - ddp_comm_hook=self._ddp_comm_hook, - ddp_comm_wrapper=self._ddp_comm_wrapper, - model=self._model, - is_single_process_single_device=self.is_single_process_single_device, - ) def configure_ddp(self): self.pre_configure_ddp() @@ -252,7 +247,7 @@ def configure_ddp(self): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) - self.register_model_hook() + self._register_ddp_hooks() def determine_ddp_device_ids(self): if self.root_device.type == "cpu": diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index 9690c468e86c2..f19fb05a16233 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -60,7 +60,7 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_single_process_single_device(self): + def _is_single_process_single_device(self) -> bool: return False def set_world_ranks(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py b/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py deleted file mode 100644 index 5acdedb992e7f..0000000000000 --- a/pytorch_lightning/plugins/training_type/ddp_comm_hook_util.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -from typing import Optional - -from pytorch_lightning.utilities import ( - _TORCH_GREATER_EQUAL_1_9, - rank_zero_warn, - rank_zero_info, -) -from torch.nn.parallel.distributed import DistributedDataParallel - - -def register_ddp_comm_hook( - ddp_comm_state: Optional[object], - ddp_comm_hook: Optional[callable], - ddp_comm_wrapper: Optional[callable], - model: DistributedDataParallel, - is_single_process_single_device: bool, -) -> None: - # register DDP comm hook: https://pytorch.org/docs/master/ddp_comm_hooks.html - if ddp_comm_hook is None: - rank_zero_info("No DDP comm hook is provided, skipping.") - return - if not is_single_process_single_device: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, must be single process single device, see " - "https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/torch/nn/parallel/distributed.py#L1035" - ) - if ddp_comm_wrapper is not None: - if not _TORCH_GREATER_EQUAL_1_9: - rank_zero_warn( - "Not applying DDP comm wrapper. " - "To use communication wrapper, please use PyTorch version at least 1.9.0." - ) - else: - rank_zero_info( - "DDP comm wrapper is provided, apply ddp_comm_wrapper(ddp_comm_hook)." 
- ) - ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) - - rank_zero_info("Registering DDP comm hook.") - model.register_comm_hook( - state=ddp_comm_state, - hook=ddp_comm_hook, - ) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 9307b6423b0df..195ed29f9cb55 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -34,7 +34,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything if _TORCH_GREATER_EQUAL_1_7: - from pytorch_lightning.plugins.training_type.ddp_comm_hook_util import register_ddp_comm_hook + from pytorch_lightning.utilities.distributed import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def __init__( self.sync_batchnorm = sync_batchnorm self._ddp_kwargs = kwargs self.dist = LightningDistributed() - self.num_processes = len(parallel_devices) + self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.node_rank = 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state @@ -85,12 +85,11 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_single_process_single_device(self): + def _is_single_process_single_device(self): return True def setup(self, model): os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - # pass in a state q smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() @@ -193,20 +192,15 @@ def pre_configure_ddp(self): ) self._ddp_kwargs["find_unused_parameters"] = True - def register_model_hook(self) -> None: - if not _TORCH_GREATER_EQUAL_1_7: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, please use PyTorch version at least 1.7.0." 
+ def _register_ddp_hooks(self) -> None: + if _TORCH_GREATER_EQUAL_1_7: + register_ddp_comm_hook( + model=self._model, + is_single_process_single_device=self._is_single_process_single_device, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, ) - return - register_ddp_comm_hook( - ddp_comm_state=self._ddp_comm_state, - ddp_comm_hook=self._ddp_comm_hook, - ddp_comm_wrapper=self._ddp_comm_wrapper, - model=self._model, - is_single_process_single_device=self.is_single_process_single_device, - ) def configure_ddp(self): self.pre_configure_ddp() @@ -215,7 +209,7 @@ def configure_ddp(self): device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs, ) - self.register_model_hook() + self._register_ddp_hooks() def init_ddp_connection(self, global_rank: int, world_size: int) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 658f349a22215..8a8a690146e55 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -17,9 +17,15 @@ import warnings from functools import wraps from typing import Any, Optional, Union +from pytorch_lightning.utilities.imports import ( + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_9, +) import torch +from torch.nn.parallel.distributed import DistributedDataParallel + log = logging.getLogger(__name__) if torch.distributed.is_available(): @@ -197,3 +203,112 @@ def all_gather_ddp_if_available( with torch.no_grad(): return AllGatherGrad.apply(tensor, group) return tensor + + +def register_ddp_comm_hook( + model: DistributedDataParallel, + is_single_process_single_device: bool, + ddp_comm_state: Optional[object] = None, + ddp_comm_hook: Optional[callable] = None, + ddp_comm_wrapper: Optional[callable] = None, +) -> None: + """ + Function to register communication hook for DDP model + https://pytorch.org/docs/master/ddp_comm_hooks.html + + Args: + model: DDP model + is_single_process_single_device: whether it is single-process single-device mode + ddp_comm_state: state is passed to the hook and can be used to maintain + and update any state information that users would like to + maintain as part of the training process. Examples: error + feedback in gradient compression, peers to communicate with + next in GossipGrad etc. + ddp_comm_hook: hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future: + + This function is called once the bucket is ready. The + hook can perform whatever processing is needed and return + a Future indicating completion of any async work (ex: allreduce). + If the hook doesn't perform any communication, it can also + just return a completed Future. The Future should hold the + new value of grad bucket's tensors. Once a bucket is ready, + c10d reducer would call this hook and use the tensors returned + by the Future and copy grads to individual parameters. + + ddp_comm_wrapper: communication hook wraper to support fp16_compress_hook() as wrapper, + which could be combined with ddp_comm_hook + + .. warning :: + DDP communication hook need pytorch version at least 1.7.0 + + .. warning :: + DDP communication hook does not support single-process multiple-device mode. + Gradbucket tensors should consist of only a single tensor. + + .. 
warning :: + DDP communication wrapper need pytorch version at least 1.9.0 + + Example:: + + from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as default, + powerSGD_hook as powerSGD, + ) + + # fp16_compress_hook for compress gradients + register_ddp_comm_hook( + model=ddp_model, + is_single_process_single_device=True, + ddp_comm_hook=default.fp16_compress_hook, + ) + + # powerSGD_hook + register_ddp_comm_hook( + model=ddp_model, + is_single_process_single_device=True, + ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_hook=powerSGD.powerSGD_hook, + ) + + # fp16_compress_wrapper combined with other communication hook + register_ddp_comm_hook( + model=ddp_model, + is_single_process_single_device=True, + ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_hook=powerSGD.powerSGD_hook, + ddp_comm_wrapper=default.fp16_compress_wrapper, + ) + """ + if not _TORCH_GREATER_EQUAL_1_7: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, please use pytorch>=1.7.0." + ) + return + if ddp_comm_hook is None: + return + if not is_single_process_single_device: + rank_zero_warn( + "Not registering DDP comm hook. " + "To use communication hooks, must be single process single device, see " + "https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/" + "torch/nn/parallel/distributed.py#L1035" + ) + return + if ddp_comm_wrapper is not None: + if not _TORCH_GREATER_EQUAL_1_9: + rank_zero_warn( + "Not applying DDP comm wrapper. " + "To use communication wrapper, please use pytorch>=1.9.0." + ) + else: + rank_zero_info( + "DDP comm wrapper is provided, apply ddp_comm_wrapper(ddp_comm_hook)." + ) + ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) + + rank_zero_debug("Registering DDP comm hook.") + model.register_comm_hook( + state=ddp_comm_state, + hook=ddp_comm_hook, + ) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index a77df6d6471b2..6b04ac7b38708 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -15,12 +15,6 @@ from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel from tests.helpers.runif import RunIf -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 -if _TORCH_GREATER_EQUAL_1_7: - from torch.distributed.algorithms.ddp_comm_hooks import ( - default_hooks as default, - powerSGD_hook as powerSGD, - ) class CustomParallelPlugin(DDPPlugin): @@ -44,54 +38,3 @@ def test_sync_batchnorm_set(tmpdir): ) trainer.fit(model) assert plugin.sync_batchnorm is True - - -@RunIf(skip_windows=True, min_torch="1.7.0") -def test_ddp_fp16_compress_comm_hook(tmpdir): - """Test for DDP FP16 compress hook.""" - model = BoringModel() - plugin = DDPPlugin( - ddp_comm_hook=default.fp16_compress_hook, - ) - trainer = Trainer( - max_epochs=1, - plugins=[plugin], - default_root_dir=tmpdir, - sync_batchnorm=True, - ) - trainer.fit(model) - - -@RunIf(skip_windows=True, min_torch="1.7.0") -def test_ddp_sgd_graident_comm_hook(tmpdir): - """Test for DDP SGD hook.""" - model = BoringModel() - plugin = DDPPlugin( - ddp_comm_state=powerSGD.PowerSGDState(None), - ddp_comm_hook=powerSGD.powerSGD_hook, - ) - trainer = Trainer( - max_epochs=1, - plugins=[plugin], - default_root_dir=tmpdir, - sync_batchnorm=True, - ) - trainer.fit(model) - - -@RunIf(skip_windows=True, min_torch="1.9.0") -def test_ddp_fp16_compress_wrapper_comm_hook(tmpdir): - """Test for DDP fp16 compress wrapper for SGD hook.""" - model = 
BoringModel() - plugin = DDPPlugin( - ddp_comm_state=powerSGD.PowerSGDState(None), - ddp_comm_hook=powerSGD.powerSGD_hook, - ddp_comm_wrapper=default.fp16_compress_wrapper, - ) - trainer = Trainer( - max_epochs=1, - plugins=[plugin], - default_root_dir=tmpdir, - sync_batchnorm=True, - ) - trainer.fit(model) diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py new file mode 100644 index 0000000000000..ed1adc48dae60 --- /dev/null +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -0,0 +1,70 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.plugins import DDPPlugin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +if _TORCH_GREATER_EQUAL_1_7: + from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks as default, + powerSGD_hook as powerSGD, + ) + + +class BoringHalfGrdadientPrecisionCheckModel(BoringModel): + + def on_after_backward(self): + super().on_after_backward() + for k, v in self.named_parameters(): + assert v.grad.dtype == torch.half + +@RunIf(skip_windows=True, min_torch="1.7.0", min_gpus=2) +def test_ddp_fp16_compress_comm_hook(tmpdir): + """Test for DDP FP16 compress hook.""" + model = BoringHalfGrdadientPrecisionCheckModel() + training_type_plugin = DDPPlugin( + ddp_comm_hook=default.fp16_compress_hook, + sync_batchnorm=True, + ) + trainer = Trainer( + max_epochs=1, + gpus=2, + plugins=[training_type_plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2) +def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): + """Test for DDP FP16 compress wrapper for SGD hook.""" + model = BoringHalfGrdadientPrecisionCheckModel() + training_type_plugin = DDPPlugin( + ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_hook=powerSGD.powerSGD_hook, + ddp_comm_wrapper=default.fp16_compress_wrapper, + sync_batchnorm=True, + ) + trainer = Trainer( + max_epochs=1, + gpus=2, + plugins=[training_type_plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" From f8d06035720c354becbdc4ab764108cb6f50f9e2 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 11:47:57 -0700 Subject: [PATCH 37/49] 1. add is gpu checking 2. modify test a bit 3. 
formatting --- .../plugins/training_type/ddp.py | 5 +- .../plugins/training_type/ddp_spawn.py | 5 +- pytorch_lightning/utilities/distributed.py | 8 ++- .../plugins/test_ddp_plugin_with_comm_hook.py | 68 +++++++++++++++---- 4 files changed, 68 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 64b7748e12339..b4a55e22d3e46 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -231,7 +231,10 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - if _TORCH_GREATER_EQUAL_1_7: + # currently, DDP communication hooks only work with NCCL backend + # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ + # torch/nn/parallel/distributed.py#L1040 + if _TORCH_GREATER_EQUAL_1_7 and self.on_gpu: register_ddp_comm_hook( model=self._model, is_single_process_single_device=self._is_single_process_single_device, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 195ed29f9cb55..4032c0d79c3ac 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -193,7 +193,10 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - if _TORCH_GREATER_EQUAL_1_7: + # currently, DDP communication hooks only work with NCCL backend + # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ + # torch/nn/parallel/distributed.py#L1040 + if _TORCH_GREATER_EQUAL_1_7 and self.on_gpu: register_ddp_comm_hook( model=self._model, is_single_process_single_device=self._is_single_process_single_device, diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 8a8a690146e55..4856759fd2f02 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -248,7 +248,7 @@ def register_ddp_comm_hook( .. warning :: DDP communication wrapper need pytorch version at least 1.9.0 - Example:: + Example: from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as default, @@ -303,11 +303,13 @@ def register_ddp_comm_hook( ) else: rank_zero_info( - "DDP comm wrapper is provided, apply ddp_comm_wrapper(ddp_comm_hook)." + "DDP comm wrapper is provided, apply {}({}).".format( + ddp_comm_wrapper.__qualname__, ddp_comm_hook.__qualname__ + ) ) ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) - rank_zero_debug("Registering DDP comm hook.") + rank_zero_debug("Registering DDP comm hook: {}.".format(ddp_comm_hook.__qualname__)) model.register_comm_hook( state=ddp_comm_state, hook=ddp_comm_hook, diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index ed1adc48dae60..390503c34d5e3 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import torch from pytorch_lightning import Trainer -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.plugins import DDPPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 + if _TORCH_GREATER_EQUAL_1_7: from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as default, @@ -24,17 +26,10 @@ ) -class BoringHalfGrdadientPrecisionCheckModel(BoringModel): - - def on_after_backward(self): - super().on_after_backward() - for k, v in self.named_parameters(): - assert v.grad.dtype == torch.half - @RunIf(skip_windows=True, min_torch="1.7.0", min_gpus=2) def test_ddp_fp16_compress_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" - model = BoringHalfGrdadientPrecisionCheckModel() + model = BoringModel() training_type_plugin = DDPPlugin( ddp_comm_hook=default.fp16_compress_hook, sync_batchnorm=True, @@ -45,14 +40,51 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): plugins=[training_type_plugin], default_root_dir=tmpdir, sync_batchnorm=True, + fast_dev_run=True, + ) + trainer.fit(model) + trainer_comm_hook = ( + trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + ) + expected_comm_hook = default.fp16_compress_hook.__qualname__ + assert trainer_comm_hook == expected_comm_hook + assert ( + trainer.state == TrainerState.FINISHED + ), f"Training failed with {trainer.state}" + + +@RunIf(skip_windows=True, min_torch="1.7.0", min_gpus=2) +def test_ddp_sgd_comm_hook(tmpdir): + """Test for DDP FP16 compress hook.""" + model = BoringModel() + training_type_plugin = DDPPlugin( + ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_hook=powerSGD.powerSGD_hook, + sync_batchnorm=True, + ) + trainer = Trainer( + max_epochs=1, + gpus=2, + plugins=[training_type_plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + fast_dev_run=True, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer_comm_hook = ( + trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + ) + expected_comm_hook = powerSGD.powerSGD_hook.__qualname__ + assert trainer_comm_hook == expected_comm_hook + assert ( + trainer.state == TrainerState.FINISHED + ), f"Training failed with {trainer.state}" + @RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2) def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress wrapper for SGD hook.""" - model = BoringHalfGrdadientPrecisionCheckModel() + model = BoringModel() training_type_plugin = DDPPlugin( ddp_comm_state=powerSGD.PowerSGDState(process_group=None), ddp_comm_hook=powerSGD.powerSGD_hook, @@ -65,6 +97,16 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): plugins=[training_type_plugin], default_root_dir=tmpdir, sync_batchnorm=True, + fast_dev_run=True, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer_comm_hook = ( + trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + ) + expected_comm_hook = default.fp16_compress_wrapper( + powerSGD.powerSGD_hook + ).__qualname__ + assert trainer_comm_hook == expected_comm_hook + assert ( + trainer.state == TrainerState.FINISHED + ), f"Training failed with {trainer.state}" From f06285f60eed27bd8d0f4e826cb5385febf01537 Mon 
Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 12:09:03 -0700 Subject: [PATCH 38/49] formatting nit --- tests/plugins/test_ddp_plugin_with_comm_hook.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 390503c34d5e3..35721e1e3916c 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.trainer.states import TrainerState From b607ebd9407300b832b87d517ff35dc1144cca93 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 13:47:13 -0700 Subject: [PATCH 39/49] fix conda 3.7 1.7 issue for no torch.distributed.algorithms module --- tests/plugins/test_ddp_plugin_with_comm_hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 35721e1e3916c..516a50e63f2c4 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.trainer.states import TrainerState @@ -18,7 +19,7 @@ from tests.helpers import BoringModel from tests.helpers.runif import RunIf -if _TORCH_GREATER_EQUAL_1_7: +if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_7: from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as default, powerSGD_hook as powerSGD, From 6cc9dfabbdbb6f670170a71001f6497f9989daa1 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 20:16:01 -0700 Subject: [PATCH 40/49] need at least 1.8.0 --- pytorch_lightning/plugins/training_type/ddp.py | 11 ++++++++--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/distributed.py | 8 ++++---- tests/plugins/test_ddp_plugin_with_comm_hook.py | 8 ++++---- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index b4a55e22d3e46..c18b2f6167b80 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,7 +29,12 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn +from pytorch_lightning.utilities import ( + _HYDRA_AVAILABLE, + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, + rank_zero_warn, +) from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import 
seed_everything @@ -37,7 +42,7 @@ if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path -if _TORCH_GREATER_EQUAL_1_7: +if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -234,7 +239,7 @@ def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 - if _TORCH_GREATER_EQUAL_1_7 and self.on_gpu: + if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu: register_ddp_comm_hook( model=self._model, is_single_process_single_device=self._is_single_process_single_device, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4032c0d79c3ac..66587ab730ddf 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything -if _TORCH_GREATER_EQUAL_1_7: +if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook log = logging.getLogger(__name__) @@ -196,7 +196,7 @@ def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 - if _TORCH_GREATER_EQUAL_1_7 and self.on_gpu: + if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu: register_ddp_comm_hook( model=self._model, is_single_process_single_device=self._is_single_process_single_device, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index e69a9947d9f8b..eeee77a7a8960 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -39,6 +39,7 @@ _RPC_AVAILABLE, _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE, diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 4856759fd2f02..ad19cd90f8c29 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -18,7 +18,7 @@ from functools import wraps from typing import Any, Optional, Union from pytorch_lightning.utilities.imports import ( - _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, ) @@ -239,7 +239,7 @@ def register_ddp_comm_hook( which could be combined with ddp_comm_hook .. warning :: - DDP communication hook need pytorch version at least 1.7.0 + DDP communication hook need pytorch version at least 1.8.0 .. warning :: DDP communication hook does not support single-process multiple-device mode. @@ -279,10 +279,10 @@ def register_ddp_comm_hook( ddp_comm_wrapper=default.fp16_compress_wrapper, ) """ - if not _TORCH_GREATER_EQUAL_1_7: + if not _TORCH_GREATER_EQUAL_1_8: rank_zero_warn( "Not registering DDP comm hook. " - "To use communication hooks, please use pytorch>=1.7.0." + "To use communication hooks, please use pytorch>=1.8.0." 
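The _TORCH_GREATER_EQUAL_1_8 flag gating this path lives in pytorch_lightning/utilities/imports.py, which this series only imports and never shows. For orientation, such a flag amounts to a version comparison along these lines — a sketch only, not the actual imports.py code, which goes through a shared comparison helper:

    import torch
    from packaging.version import Version

    # Sketch: resolve the installed torch version once at import time.
    # Local/dev suffixes such as "+cu111" are handled by packaging's Version.
    _TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__) >= Version("1.8.0")
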
) return if ddp_comm_hook is None: diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 516a50e63f2c4..0b9ef9fd05742 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -15,18 +15,18 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from tests.helpers import BoringModel from tests.helpers.runif import RunIf -if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_7: +if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8: from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as default, powerSGD_hook as powerSGD, ) -@RunIf(skip_windows=True, min_torch="1.7.0", min_gpus=2) +@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2) def test_ddp_fp16_compress_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() @@ -53,7 +53,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): ), f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.7.0", min_gpus=2) +@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2) def test_ddp_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() From b12a16bff9e63bf9b49deb4c1df416958434e866 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 20:19:55 -0700 Subject: [PATCH 41/49] minor fix --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 66587ab730ddf..5d6beb37824e4 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -28,7 +28,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available From 25ccb820270fe628c246762d9d8aac423266cb61 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 21:10:27 -0700 Subject: [PATCH 42/49] modify changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb1aa7481d7e0..ba3b8816559ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) +- Added support for DDP communication hooks ([#6727] (https://github.com/PyTorchLightning/pytorch-lightning/issues/6727)) ### Changed From 35d49bca1a3efe610e0537093499ec72ac032d04 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Wed, 31 Mar 2021 21:14:46 -0700 Subject: [PATCH 43/49] changelog should link to PR number instead of issue number --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba3b8816559ce..c11a3fdfa6ddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) -- Added support for DDP communication hooks ([#6727] (https://github.com/PyTorchLightning/pytorch-lightning/issues/6727)) +- Added support for DDP communication hooks ([#6736] (https://github.com/PyTorchLightning/pytorch-lightning/issues/6736)) ### Changed From dc5c55ce76c2dee1062bf32b9179c9aa2a9d074a Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 1 Apr 2021 11:09:20 -0700 Subject: [PATCH 44/49] refine a bit on doc for register_ddp_comm_hook function, like ddp_comm_wrapper explanation and add hyperparameter for power sgd states in example usge --- pytorch_lightning/utilities/distributed.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ad19cd90f8c29..02f1daece16f2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -235,8 +235,9 @@ def register_ddp_comm_hook( c10d reducer would call this hook and use the tensors returned by the Future and copy grads to individual parameters. - ddp_comm_wrapper: communication hook wraper to support fp16_compress_hook() as wrapper, - which could be combined with ddp_comm_hook + ddp_comm_wrapper: communication hook wraper to support a communication hook such + as FP16 compression as wrapper, which could be combined with + ddp_comm_hook .. 
warning :: DDP communication hook need pytorch version at least 1.8.0 @@ -266,7 +267,11 @@ def register_ddp_comm_hook( register_ddp_comm_hook( model=ddp_model, is_single_process_single_device=True, - ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_state=powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + start_powerSGD_iter=5000, + ), ddp_comm_hook=powerSGD.powerSGD_hook, ) @@ -274,7 +279,11 @@ def register_ddp_comm_hook( register_ddp_comm_hook( model=ddp_model, is_single_process_single_device=True, - ddp_comm_state=powerSGD.PowerSGDState(process_group=None), + ddp_comm_state=powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + start_powerSGD_iter=5000, + ), ddp_comm_hook=powerSGD.powerSGD_hook, ddp_comm_wrapper=default.fp16_compress_wrapper, ) From fb184b288d9ed100b38d963a8355fae13de9ac3b Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Thu, 1 Apr 2021 15:52:11 -0700 Subject: [PATCH 45/49] move single device checking before call register_ddp_comm_hook --- pytorch_lightning/plugins/training_type/ddp.py | 9 ++++++--- .../plugins/training_type/ddp_spawn.py | 9 ++++++--- pytorch_lightning/utilities/distributed.py | 17 ----------------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index c18b2f6167b80..9a87179e79282 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -236,13 +236,16 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend + # currently, DDP communication hooks only work with NCCL backend and singlge process single device mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 - if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu: + if ( + _TORCH_GREATER_EQUAL_1_8 + and self.on_gpu + and self._is_single_process_single_device + ): register_ddp_comm_hook( model=self._model, - is_single_process_single_device=self._is_single_process_single_device, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5d6beb37824e4..b22dbe76eab75 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -193,13 +193,16 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend + # currently, DDP communication hooks only work with NCCL backend and singlge process single device mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 - if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu: + if ( + _TORCH_GREATER_EQUAL_1_8 + and self.on_gpu + and self._is_single_process_single_device + ): register_ddp_comm_hook( model=self._model, - is_single_process_single_device=self._is_single_process_single_device, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 
02f1daece16f2..968de44bfff73 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -207,7 +207,6 @@ def all_gather_ddp_if_available( def register_ddp_comm_hook( model: DistributedDataParallel, - is_single_process_single_device: bool, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -218,7 +217,6 @@ def register_ddp_comm_hook( Args: model: DDP model - is_single_process_single_device: whether it is single-process single-device mode ddp_comm_state: state is passed to the hook and can be used to maintain and update any state information that users would like to maintain as part of the training process. Examples: error @@ -242,10 +240,6 @@ def register_ddp_comm_hook( .. warning :: DDP communication hook need pytorch version at least 1.8.0 - .. warning :: - DDP communication hook does not support single-process multiple-device mode. - Gradbucket tensors should consist of only a single tensor. - .. warning :: DDP communication wrapper need pytorch version at least 1.9.0 @@ -259,14 +253,12 @@ def register_ddp_comm_hook( # fp16_compress_hook for compress gradients register_ddp_comm_hook( model=ddp_model, - is_single_process_single_device=True, ddp_comm_hook=default.fp16_compress_hook, ) # powerSGD_hook register_ddp_comm_hook( model=ddp_model, - is_single_process_single_device=True, ddp_comm_state=powerSGD.PowerSGDState( process_group=None, matrix_approximation_rank=1, @@ -278,7 +270,6 @@ def register_ddp_comm_hook( # fp16_compress_wrapper combined with other communication hook register_ddp_comm_hook( model=ddp_model, - is_single_process_single_device=True, ddp_comm_state=powerSGD.PowerSGDState( process_group=None, matrix_approximation_rank=1, @@ -296,14 +287,6 @@ def register_ddp_comm_hook( return if ddp_comm_hook is None: return - if not is_single_process_single_device: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, must be single process single device, see " - "https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/" - "torch/nn/parallel/distributed.py#L1035" - ) - return if ddp_comm_wrapper is not None: if not _TORCH_GREATER_EQUAL_1_9: rank_zero_warn( From bf44378b624b5151f000722d8d35d85378d1d51d Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Fri, 2 Apr 2021 14:51:53 -0700 Subject: [PATCH 46/49] formatting --- pytorch_lightning/utilities/distributed.py | 53 +++++++++++----------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 968de44bfff73..200f435a9bc94 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -216,32 +216,35 @@ def register_ddp_comm_hook( https://pytorch.org/docs/master/ddp_comm_hooks.html Args: - model: DDP model - ddp_comm_state: state is passed to the hook and can be used to maintain - and update any state information that users would like to - maintain as part of the training process. Examples: error - feedback in gradient compression, peers to communicate with - next in GossipGrad etc. - ddp_comm_hook: hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future: - - This function is called once the bucket is ready. The - hook can perform whatever processing is needed and return - a Future indicating completion of any async work (ex: allreduce). 
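The contract restated in this docstring, hook(state, bucket) -> torch.futures.Future, is easier to read next to a concrete instance. Below is a minimal hand-rolled averaging allreduce hook, sketched against the public GradBucket API of torch>=1.9 (bucket.buffer(); older releases expose the bucket tensor through other accessors) — an illustration, not code from this PR:

    import torch
    import torch.distributed as dist

    def averaging_allreduce_hook(state, bucket):
        # state is whatever was passed as ddp_comm_state (a process group or None here).
        group = state if state is not None else dist.group.WORLD
        tensor = bucket.buffer()        # flattened gradients of this bucket (torch>=1.9)
        tensor.div_(group.size())       # pre-divide so the allreduce yields a mean
        fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()
        # The returned Future must resolve to the bucket's (now averaged) tensor.
        return fut.then(lambda f: f.value()[0])

Passed through DDPPlugin(ddp_comm_hook=averaging_allreduce_hook), the register_ddp_comm_hook call above hands ddp_comm_state to the hook as its state argument.
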
- If the hook doesn't perform any communication, it can also - just return a completed Future. The Future should hold the - new value of grad bucket's tensors. Once a bucket is ready, - c10d reducer would call this hook and use the tensors returned - by the Future and copy grads to individual parameters. - - ddp_comm_wrapper: communication hook wraper to support a communication hook such - as FP16 compression as wrapper, which could be combined with - ddp_comm_hook + model: + DDP model + ddp_comm_state: + state is passed to the hook and can be used to maintain + and update any state information that users would like to + maintain as part of the training process. Examples: error + feedback in gradient compression, peers to communicate with + next in GossipGrad etc. + ddp_comm_hook: + hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future + + This callable function is called once the bucket is ready. The + hook can perform whatever processing is needed and return + a Future indicating completion of any async work (ex: allreduce). + If the hook doesn't perform any communication, it can also + just return a completed Future. The Future should hold the + new value of grad bucket's tensors. Once a bucket is ready, + c10d reducer would call this hook and use the tensors returned + by the Future and copy grads to individual parameters. + ddp_comm_wrapper: + communication hook wraper to support a communication hook such + as FP16 compression as wrapper, which could be combined with + ddp_comm_hook .. warning :: - DDP communication hook need pytorch version at least 1.8.0 + DDP communication hook needs pytorch version at least 1.8.0 .. warning :: - DDP communication wrapper need pytorch version at least 1.9.0 + DDP communication wrapper needs pytorch version at least 1.9.0 Example: @@ -295,13 +298,11 @@ def register_ddp_comm_hook( ) else: rank_zero_info( - "DDP comm wrapper is provided, apply {}({}).".format( - ddp_comm_wrapper.__qualname__, ddp_comm_hook.__qualname__ - ) + f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." 
) ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) - rank_zero_debug("Registering DDP comm hook: {}.".format(ddp_comm_hook.__qualname__)) + rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.") model.register_comm_hook( state=ddp_comm_state, hook=ddp_comm_hook, From d529985a0df82b6b7a5845f90176436a9455f0e6 Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 5 Apr 2021 11:02:13 -0700 Subject: [PATCH 47/49] comments --- .../plugins/training_type/ddp.py | 2 +- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/test_ddp_plugin_with_comm_hook.py | 35 ++++++++++++++++--- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 9a87179e79282..44d3b9fbdd861 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -236,7 +236,7 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend and singlge process single device mode + # currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 if ( diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b22dbe76eab75..11c2dffb397a0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -193,7 +193,7 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend and singlge process single device mode + # currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 if ( diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 0b9ef9fd05742..06c8069441706 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -13,7 +13,7 @@ # limitations under the License. 
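The test module below drives the new plugin arguments end to end on two GPUs. Stripped of the test harness, the user-facing configuration being exercised looks roughly like this — a sketch assuming a 2-GPU machine and torch>=1.9 for the fp16 wrapper, mirroring the fp16-wrapped PowerSGD test in this file:

    from torch.distributed.algorithms.ddp_comm_hooks import (
        default_hooks as default,
        powerSGD_hook as powerSGD,
    )
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin
    from tests.helpers import BoringModel  # stand-in LightningModule used by these tests

    # FP16-compressed PowerSGD, as in test_ddp_fp16_compress_wrap_sgd_comm_hook.
    training_type_plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(
            process_group=None,
            matrix_approximation_rank=1,
            start_powerSGD_iter=5000,
        ),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
    )
    trainer = Trainer(gpus=2, plugins=[training_type_plugin], fast_dev_run=True)
    trainer.fit(BoringModel())
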
import torch from pytorch_lightning import Trainer -from pytorch_lightning.plugins import DDPPlugin +from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from tests.helpers import BoringModel @@ -26,7 +26,7 @@ ) -@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2) +@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True) def test_ddp_fp16_compress_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() @@ -53,7 +53,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): ), f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2) +@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True) def test_ddp_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() @@ -81,7 +81,7 @@ def test_ddp_sgd_comm_hook(tmpdir): ), f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2) +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress wrapper for SGD hook.""" model = BoringModel() @@ -110,3 +110,30 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): assert ( trainer.state == TrainerState.FINISHED ), f"Training failed with {trainer.state}" + + +@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True) +def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): + """Test for DDP Spawn FP16 compress hook.""" + model = BoringModel() + training_type_plugin = DDPSpawnPlugin( + ddp_comm_hook=default.fp16_compress_hook, + sync_batchnorm=True, + ) + trainer = Trainer( + max_epochs=1, + gpus=2, + plugins=[training_type_plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + fast_dev_run=True, + ) + trainer.fit(model) + trainer_comm_hook = ( + trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + ) + expected_comm_hook = default.fp16_compress_hook.__qualname__ + assert trainer_comm_hook == expected_comm_hook + assert ( + trainer.state == TrainerState.FINISHED + ), f"Training failed with {trainer.state}" From b8105be9e4087f49e120514a01afaebfeda610ee Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Mon, 5 Apr 2021 12:54:55 -0700 Subject: [PATCH 48/49] typo --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- tests/plugins/test_ddp_plugin_with_comm_hook.py | 5 ----- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 44d3b9fbdd861..956cb424d7a11 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -236,7 +236,7 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 if ( diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 11c2dffb397a0..0eff11a04f2c9 100644 --- 
a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -193,7 +193,7 @@ def pre_configure_ddp(self): self._ddp_kwargs["find_unused_parameters"] = True def _register_ddp_hooks(self) -> None: - # currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ # torch/nn/parallel/distributed.py#L1040 if ( diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 06c8069441706..25845de1aeb8a 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -129,11 +129,6 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = ( - trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook - ) - expected_comm_hook = default.fp16_compress_hook.__qualname__ - assert trainer_comm_hook == expected_comm_hook assert ( trainer.state == TrainerState.FINISHED ), f"Training failed with {trainer.state}" From e32a11d6f902cf26e0ffb8f96c2a8f232c530eed Mon Sep 17 00:00:00 2001 From: Shuying Sun Date: Tue, 6 Apr 2021 12:30:17 -0700 Subject: [PATCH 49/49] pre-commit formatting --- CHANGELOG.md | 2 +- pytorch_lightning/plugins/training_type/ddp.py | 9 ++------- pytorch_lightning/plugins/training_type/ddp_spawn.py | 10 +++------- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c11a3fdfa6ddb..e6a490c2f8aff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) -- Added support for DDP communication hooks ([#6736] (https://github.com/PyTorchLightning/pytorch-lightning/issues/6736)) +- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736)) ### Changed diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 956cb424d7a11..e777fbd3f89c2 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -237,13 +237,8 @@ def pre_configure_ddp(self): def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode - # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ - # torch/nn/parallel/distributed.py#L1040 - if ( - _TORCH_GREATER_EQUAL_1_8 - and self.on_gpu - and self._is_single_process_single_device - ): + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device): register_ddp_comm_hook( model=self._model, ddp_comm_state=self._ddp_comm_state, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0eff11a04f2c9..0da0cc17a67eb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything + if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -194,13 +195,8 @@ def pre_configure_ddp(self): def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode - # https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/ - # torch/nn/parallel/distributed.py#L1040 - if ( - _TORCH_GREATER_EQUAL_1_8 - and self.on_gpu - and self._is_single_process_single_device - ): + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device): register_ddp_comm_hook( model=self._model, ddp_comm_state=self._ddp_comm_state,