From 1aefd92f5d9cc28a3da9337018650b1f585f202a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 03:20:47 +0200 Subject: [PATCH 1/4] Use typing forward references --- pytorch_lightning/core/lightning.py | 69 +++++++++---------- pytorch_lightning/overrides/base.py | 6 +- pytorch_lightning/overrides/data_parallel.py | 4 +- pytorch_lightning/overrides/distributed.py | 4 +- .../plugins/precision/apex_amp.py | 2 +- .../plugins/training_type/deepspeed.py | 6 +- .../plugins/training_type/parallel.py | 7 +- tests/accelerators/test_multi_nodes_gpu.py | 2 +- 8 files changed, 48 insertions(+), 52 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cf6e25c54f336..9582618323d7c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -324,42 +324,39 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) - if self._results is not None: - # TODO: if logged twice fail with crash - - # set the default depending on the fx_name - on_step = self.__auto_choose_log_on_step(on_step) - on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - - assert self._current_fx_name is not None - self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) - - # make sure user doesn't introduce logic for multi-dataloaders - if "/dataloader_idx_" in name: - raise MisconfigurationException( - f"Logged key: {name} should not contain information about dataloader_idx." - ) - - value = self.__sync( - value, - sync_fn=self.trainer.training_type_plugin.reduce, - sync_dist=sync_dist, - sync_dist_op=sync_dist_op, - sync_dist_group=sync_dist_group, - device=self.device, - ) + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + + assert self._current_fx_name is not None + self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.") + + value = self.__sync( + value, + sync_fn=self.trainer.training_type_plugin.reduce, + sync_dist=sync_dist, + sync_dist_op=sync_dist_op, + sync_dist_group=sync_dist_group, + device=self.device, + ) - self._results.log( - name, - value, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - ) + # TODO: if logged twice fail with crash + assert self._results is not None + self._results.log( + name, + value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), + ) def log_dict( self, @@ -378,7 +375,7 @@ def log_dict( add_dataloader_idx: bool = True, ) -> None: """ - Log a dictonary of values at once + Log a dictionary of values at once Example:: diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 88e8ed6375e1b..e086779bec901 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -15,13 +15,13 @@ from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel -from 
pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step`` or ``test_step``. @@ -66,7 +66,7 @@ def on_post_move_to_device(self): pass -def unwrap_lightning_module(wrapped_model) -> LightningModule: +def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule': model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = model.module diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 3d6e527ef95a9..57919db6ab221 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -17,7 +17,7 @@ import torch -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): """ - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: super().__init__(pl_module) _ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 559e1161ce676..d4b1e6ed22d55 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -18,13 +18,13 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, DistributedSampler, Sampler -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase class LightningDistributedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule') -> None: """ Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. 
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 71c2119e734fd..aa3aad7689cf0 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -39,7 +39,7 @@ def __init__(self, amp_level: str = "O2") -> None: def master_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) - def dispatch(self, trainer: "pl.Trainer") -> None: + def dispatch(self, trainer: 'pl.Trainer') -> None: if not self._connected: accelerator = trainer.accelerator _, accelerator.optimizers = amp.initialize( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8dd04aafa6b86..3481986f2102f 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -22,8 +22,8 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule, precision: int): + def __init__(self, pl_module: 'pl.LightningModule', precision: int) -> None: super().__init__(pl_module) self.precision = precision @@ -378,7 +378,7 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> Tuple[List, List, List]: # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled # via `_initialize_deepspeed_train` diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index a8028e5be1a69..09e48a760e868 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -19,7 +19,7 @@ import torch from torch.nn.parallel import DistributedDataParallel -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin @@ -99,7 +99,7 @@ def torch_distributed_backend(self): return torch_backend @staticmethod - def configure_sync_batchnorm(model: LightningModule) -> LightningModule: + def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. 
@@ -112,8 +112,7 @@ def configure_sync_batchnorm(model: LightningModule) -> LightningModule: Return: LightningModule with batchnorm layers synchronized between process groups """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - return model + return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @contextmanager def block_backward_sync(self): diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index cae257666e390..463307ead8717 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -125,5 +125,5 @@ def backward(self, loss, optimizer, optimizer_idx): } # we don't want to enable val metrics during steps because it is not something that users should do - # on purpose DO NOT allow step_b... it's silly to monitor val step metrics + # on purpose DO NOT allow b_step... it's silly to monitor val step metrics assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} From b9982bc2f9078af2f3ab59a1e469c8be282194d9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 03:29:36 +0200 Subject: [PATCH 2/4] Extend support for logging a collection --- pytorch_lightning/core/lightning.py | 34 +++++++++---------- .../plugins/training_type/ddp2.py | 20 +++++------ pytorch_lightning/plugins/training_type/dp.py | 24 +++++-------- .../connectors/logger_connector/result.py | 10 ------ pytorch_lightning/utilities/types.py | 1 + .../logging_/test_train_loop_logging.py | 14 ++++++++ 6 files changed, 51 insertions(+), 52 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9582618323d7c..a24dc9d367f87 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -46,7 +46,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache if TYPE_CHECKING: @@ -261,7 +261,7 @@ def forward(self, x): def log( self, name: str, - value: Any, + value: _METRIC_COLLECTION, prog_bar: bool = False, logger: bool = True, on_step: Optional[bool] = None, @@ -324,6 +324,9 @@ def log( ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`' ) + # check for none values + apply_to_collection(value, type(None), partial(self.__check_none, name, value)) + # set the default depending on the fx_name on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) @@ -335,16 +338,16 @@ def log( if "/dataloader_idx_" in name: raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.") - value = self.__sync( - value, + sync_fn = partial( + self.__sync, sync_fn=self.trainer.training_type_plugin.reduce, sync_dist=sync_dist, sync_dist_op=sync_dist_op, sync_dist_group=sync_dist_group, device=self.device, ) + value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn) - # TODO: if logged twice fail with crash assert self._results is not None self._results.log( name, @@ -360,7 +363,7 @@ def log( def log_dict( self, - dictionary: dict, + dictionary: Dict[str, 
_METRIC_COLLECTION], prog_bar: bool = False, logger: bool = True, on_step: Optional[bool] = None, @@ -417,29 +420,26 @@ def log_dict( @staticmethod def __sync( - value: _METRIC, + value: Union[torch.Tensor, numbers.Number], sync_fn: Optional[Callable] = None, sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, device: torch.device = None, - ) -> _METRIC: + ) -> torch.Tensor: """Sync across workers when using distributed training""" - if not isinstance(value, (torch.Tensor, numbers.Number)): - return value - + if isinstance(value, numbers.Number): + value = torch.tensor(value, device=device, dtype=torch.float) sync_fn = sync_fn or sync_ddp_if_available dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() if not sync_dist or not dist_available: return value - - # TODO: Find a way to make the reduction only once, so we don't need to clone. - if isinstance(value, torch.Tensor): - value = value.clone() - else: - value = torch.tensor(value, device=device, dtype=torch.float) return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) + @staticmethod + def __check_none(name: str, value: Any, _) -> Any: + raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged') + def write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' ): diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index ecf6997cba321..185e955135141 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -14,7 +14,8 @@ import torch from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import Result +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DDP2Plugin(DDPPlugin): @@ -34,26 +35,25 @@ def setup(self, model): self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here - def reduce(self, tensor, *args, **kwargs): + def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """ - Reduces a tensor from all processes to one aggregated tensor. + Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2, the reduction here is only across local devices within the node. Args: - tensor: the tensor to sync and reduce + collection: The collection of tensors to sync and reduce. *args: ignored for DDP2 **kwargs: ignored for DDP2 Return: - reduced value, except when the input was not a tensor the output remains is unchanged + Reduced tensor values or the same value if it was not or did not contain a tensor. 
""" - if isinstance(tensor, Result): - tensor.dp_reduce() - elif isinstance(tensor, torch.Tensor): - tensor = tensor.mean() + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) - return tensor + return apply_to_collection(collection, torch.Tensor, mean) @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index bb6f25a0eed36..18aeb6a451d4a 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -18,8 +18,8 @@ from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.trainer.connectors.logger_connector.result import Result from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DataParallelPlugin(ParallelPlugin): @@ -52,30 +52,24 @@ def setup(self, model): model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) - def reduce(self, tensor, *args, **kwargs): + def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """ - Reduces a tensor from all parallel processes to one aggregated tensor. + Reduces a collection of tensors from all processes. It can be applied to just a single tensor. Args: - tensor: the tensor to sync and reduce + collection: The collection of tensors to sync and reduce. *args: ignored for DP **kwargs: ignored for DP Return: - reduced value, except when the input was not a tensor the output remains is unchanged + Reduced tensor values or the same value if it was not or did not contain a tensor. 
""" - if isinstance(tensor, Result): - tensor.dp_reduce() - else: + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) - def _reduce(t: torch.Tensor): - dtype_tensor = t.dtype - return t.float().mean().type(dtype_tensor) - - tensor = apply_to_collection(tensor, torch.Tensor, _reduce) - - return tensor + return apply_to_collection(collection, torch.Tensor, mean) @property def root_device(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c55fb14a7eed4..c759946b4b4d0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -461,16 +461,6 @@ def reduce_across_time(cls, time_outputs): result['meta'] = meta return result - def dp_reduce(self): - for k, value in self.items(): - if k == 'meta' or isinstance(value, Metric): - continue - - if isinstance(value, list): - value = torch.tensor(value) - - self[k] = value.mean(dim=-1) - @property def should_reduce_on_epoch_end(self) -> bool: return self['meta']['_internal']['_reduce_on_epoch'] diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 8a81040af07db..f209287358f84 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,6 +23,7 @@ from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] +_METRIC_COLLECTION = Union[_METRIC, Dict[str, '_METRIC_COLLECTION']] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 4fee2ecc37f52..7fbbf5805bfb2 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -769,3 +769,17 @@ def validation_step(self, batch, batch_idx): assert trainer.callback_metrics["val_acc"] == 8 / 32. assert "train_loss" in trainer.callback_metrics + + +@pytest.mark.parametrize('value', [None, {'a': {'b': None}}]) +def test_log_none_raises(tmpdir, value): + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log("foo", value) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + model = TestModel() + with pytest.raises(ValueError, match=rf"self.log\(foo, {value}\)` was called"): + trainer.fit(model) From 064b27fc6e852843d50e268b5b74e8f4f3e1572b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 03:35:54 +0200 Subject: [PATCH 3/4] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b75e439f1662..3a1678f177b3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351)) +- Raise `ValueError` when a `None` value is `self.log`-ed ([#7771](https://github.com/PyTorchLightning/pytorch-lightning/pull/7771)) + + - Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026)) From 1da392db4d98480e61b45bfd6536cbbae9787437 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 31 May 2021 19:04:17 +0200 Subject: [PATCH 4/4] Remove recursive type --- pytorch_lightning/utilities/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index f209287358f84..4a98956b71c57 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,7 +23,8 @@ from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] -_METRIC_COLLECTION = Union[_METRIC, Dict[str, '_METRIC_COLLECTION']] +# real type is `Union[_METRIC, Dict[str, '_METRIC_COLLECTION']]` but Sphinx fails with `RecursionError` +_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader
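

---

Usage note (not part of the diffs above): a minimal sketch of what the `_METRIC_COLLECTION` change in patch 2/4 allows from user code. The model below, its layer/optimizer/dataloader setup, and the metric names are illustrative assumptions; only the dict-typed `value` for `self.log`, the `apply_to_collection`-based sync, and the `ValueError` on `None` values come from this series.

    import torch
    from torch.utils.data import DataLoader
    from pytorch_lightning import LightningModule, Trainer

    class SketchModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # `value` is now typed `_METRIC_COLLECTION`: a tensor/number or a
            # (possibly nested) dict of them, synced via `apply_to_collection`
            self.log("train", {"loss": loss, "loss_x2": loss * 2})
            # a `None` anywhere in the collection is rejected up front:
            # self.log("bad", {"a": {"b": None}})  # raises ValueError
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            # a plain tensor is indexable, so DataLoader treats each row as a sample
            return DataLoader(torch.randn(64, 32), batch_size=8)

    Trainer(fast_dev_run=1).fit(SketchModel())

The negative path (logging `None`, including nested inside a dict) is exercised by the `test_log_none_raises` test added in patch 2/4; the positive dict-logging path shown here assumes the downstream results object accepts collections, which is the direction this series moves toward.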