From 5e10b93a6cbf90c2d49dbd068f2307d92cd61b40 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 16:38:23 +0530 Subject: [PATCH 01/34] extract batch size only when it's required --- pytorch_lightning/core/lightning.py | 2 +- .../loops/epoch/training_epoch_loop.py | 4 +++ .../connectors/logger_connector/result.py | 25 +++++++------------ tests/loops/batch/test_truncated_bptt.py | 1 + 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2f359f00aba75..2b0d4a75d3596 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1633,7 +1633,7 @@ def optimizer_step( pg["lr"] = lr_scale * self.learning_rate # update params - optimizer.step(closure=optimizer_closure) + optimizer.step(clo, batch_sizesure=optimizer_closure) """ optimizer.step(closure=optimizer_closure) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 21d89a8be8b52..3a108e2f7b732 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -194,6 +194,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None: with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) + import pdb + + pdb.set_trace() + self.trainer._results.batch_size = batch_size self.batch_progress.increment_processed() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f798cf3ee2b82..bba449235b7f7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -357,6 +357,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = super().__init__() self.training = training self._batch_size = torch.tensor(1, device=device) + self._current_batch = None self.device: Optional[Union[str, torch.device]] = device @property @@ -371,13 +372,12 @@ def append_fn(v: ResultMetric) -> None: return o @property - def batch_size(self) -> torch.Tensor: - # performance: cache the `batch_size` tensor instead of re-creating it - return self._batch_size + def current_batch(self) -> Any: + return self._current_batch - @batch_size.setter - def batch_size(self, value: int) -> None: - self._batch_size = torch.tensor(value, device=self.device) + @current_batch.setter + def current_batch(self, data: Any) -> None: + self._current_batch = data def log( self, @@ -438,9 +438,10 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed" ) - if batch_size is not None: - self.batch_size = batch_size + if batch_size is None and on_epoch: + batch_size = extract_batch_size(self._current_batch) + self.batch_size = batch_size self.update_metrics(key, value) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: @@ -555,14 +556,6 @@ def fn(item: ResultMetric) -> None: apply_to_collection(self, ResultMetric, fn) - def extract_batch_size(self, batch: Any) -> int: - try: - batch_size = extract_batch_size(batch) - except RecursionError: - batch_size = 1 - self.batch_size = batch_size # the setter converts it to `Tensor` - return batch_size - def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) diff --git a/tests/loops/batch/test_truncated_bptt.py b/tests/loops/batch/test_truncated_bptt.py index 55adbc618b9f9..e27f21d5450e8 100644 --- a/tests/loops/batch/test_truncated_bptt.py +++ b/tests/loops/batch/test_truncated_bptt.py @@ -37,6 +37,7 @@ def configure_optimizers(self): def training_step(self, batch, batch_idx, hiddens): x, y = batch + breakpoint() pred, hiddens = self.lstm(x, hiddens) loss = F.mse_loss(pred, y) return {"loss": loss, "hiddens": hiddens} From 8941e32a4126fbae7bd67e9c5ad6b68df4eb5c98 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 18:14:44 +0530 Subject: [PATCH 02/34] revamp log --- pytorch_lightning/core/lightning.py | 248 ++++++++++-------- .../loops/epoch/training_epoch_loop.py | 10 +- .../logger_connector/logger_connector.py | 9 - .../connectors/logger_connector/result.py | 16 +- 4 files changed, 139 insertions(+), 144 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2b0d4a75d3596..ddc8cfb7edbe7 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -46,6 +46,7 @@ ) from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import get_model_size_mb @@ -358,122 +359,25 @@ def log( rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. """ - if tbptt_reduce_fx is not None: - rank_zero_deprecation( - "`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6." - " Please, open a discussion explaining your use-case in" - " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`" - ) - if tbptt_pad_token is not None: - rank_zero_deprecation( - "`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6." - " Please, open a discussion explaining your use-case in" - " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`" - ) - if sync_dist_op is not None: - rank_zero_deprecation( - f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6." - f" Use `self.log(reduce_fx={sync_dist_op})` instead." 
- ) - if reduce_fx == "default": - reduce_fx = sync_dist_op - elif reduce_fx == "default": - reduce_fx = "mean" - - # check for invalid values - apply_to_collection(value, dict, self.__check_not_nested, name) - apply_to_collection( - value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict) - ) - - # set the default depending on the fx_name - on_step = self.__auto_choose_log_on_step(on_step) - on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - - if self.trainer is None: - # not an error to support testing the `*_step` methods without a `Trainer` reference - rank_zero_warn( - "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." - " This is most likely because the model hasn't been passed to the `Trainer`" - ) - return - results = self.trainer._results - if results is None: - raise MisconfigurationException( - "You are trying to `self.log()` but the loop `ResultCollection` is not registered" - " yet. This is most likely because you are trying to log in a `predict` hook," - " but it doesn't support logging" - ) - if self._current_fx_name is None: - raise MisconfigurationException( - "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" - ) - _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) - - # make sure user doesn't introduce logic for multi-dataloaders - if "/dataloader_idx_" in name: - raise MisconfigurationException( - f"You called `self.log` with the key `{name}`" - " but it should not contain information about `dataloader_idx`" - ) - - value = apply_to_collection(value, numbers.Number, self.__to_tensor) - - if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): - # if we started a new epoch (running it's first batch) the hook name has changed - # reset any tensors for the new hook name - results.reset(metrics=False, fx=self._current_fx_name) - - if metric_attribute is None and isinstance(value, Metric): - if self._metric_attributes is None: - # compute once - self._metric_attributes = { - id(module): name for name, module in self.named_modules() if isinstance(module, Metric) - } - if not self._metric_attributes: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." - " You can fix this by setting an attribute for the metric in your `LightningModule`." - ) - # try to find the passed metric in the LightningModule - metric_attribute = self._metric_attributes.get(id(value), None) - if metric_attribute is None: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." - f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" - f" of {list(self._metric_attributes.values())}" - ) - - if ( - self.trainer.training - and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) - and batch_size is None - ): - raise MisconfigurationException( - "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." 
- ) - - results.log( - self._current_fx_name, - name, - value, + self.log_dict( + dictionary={name: value}, prog_bar=prog_bar, logger=logger, on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, enable_graph=enable_graph, - dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - batch_size=batch_size, - sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist=sync_dist, sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + add_dataloader_idx=add_dataloader_idx, + batch_size=batch_size, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, ) - self.trainer.logger_connector._current_fx = self._current_fx_name - def log_dict( self, dictionary: Mapping[str, _METRIC_COLLECTION], @@ -490,6 +394,7 @@ def log_dict( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, rank_zero_only: Optional[bool] = None, ) -> None: """Log a dictionary of values at once. @@ -516,29 +421,142 @@ def log_dict( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. + metric_attribute: To restore the metric state, Lightning requires the reference of the + :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute. rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. """ - for k, v in dictionary.items(): - self.log( - name=k, - value=v, + if self.trainer is None: + # not an error to support testing the `*_step` methods without a `Trainer` reference + rank_zero_warn( + "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." + " This is most likely because the model hasn't been passed to the `Trainer`" + ) + return + + if ( + self.trainer.training + and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) + and batch_size is None + ): + raise MisconfigurationException( + "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." + ) + + results = self.trainer._results + if results is None: + raise MisconfigurationException( + "You are trying to `self.log()` but the loop `ResultCollection` is not registered" + " yet. This is most likely because you are trying to log in a `predict` hook," + " but it doesn't support logging" + ) + if self._current_fx_name is None: + raise MisconfigurationException( + "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" + ) + _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + + if isinstance(reduce_fx, str): + reduce_fx = reduce_fx.lower() + + if tbptt_reduce_fx is not None: + rank_zero_deprecation( + "`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6." + " Please, open a discussion explaining your use-case in" + " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`" + ) + if tbptt_pad_token is not None: + rank_zero_deprecation( + "`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6." 
+ " Please, open a discussion explaining your use-case in" + " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`" + ) + if sync_dist_op is not None: + rank_zero_deprecation( + f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6." + f" Use `self.log(reduce_fx={sync_dist_op})` instead." + ) + if reduce_fx == "default": + reduce_fx = sync_dist_op + elif reduce_fx == "default": + reduce_fx = "mean" + + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + + if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): + # if we started a new epoch (running it's first batch) the hook name has changed + # reset any tensors for the new hook name + results.reset(metrics=False, fx=self._current_fx_name) + + # extract batch size if it's None whenever it's required + if batch_size is None: + if ( + on_epoch + and True in _FxValidator.functions[self._current_fx_name]["on_step"] + and reduce_fx in ("mean", "avg") + ): + batch_size = extract_batch_size(self.trainer._results.current_batch) + else: + batch_size = 1 + + for name, value in dictionary.items(): + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"You called `self.log` with the key `{name!r}`" + " but it should not contain information about `dataloader_idx`" + ) + + # check for invalid values + apply_to_collection(value, dict, self.__check_not_nested, name) + apply_to_collection( + value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict) + ) + value = apply_to_collection(value, numbers.Number, self.__to_tensor) + + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name for name, module in self.named_modules() if isinstance(module, Metric) + } + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(id(value), None) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name`" + f" is one of {list(self._metric_attributes.values())}." 
+ ) + + results.log( + self._current_fx_name, + name, + value, prog_bar=prog_bar, logger=logger, on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx, - add_dataloader_idx=add_dataloader_idx, + dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, + sync_dist=sync_dist and distributed_available(), + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_group=sync_dist_group, + metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, ) + self.trainer.logger_connector._current_fx = self._current_fx_name + @staticmethod def __check_not_nested(value: dict, name: str) -> dict: # self-imposed restriction. for simplicity diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 3a108e2f7b732..3794ac40b7fc0 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() - # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be - # different if tbptt is enabled - batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch) + self.trainer._results.current_batch = batch if batch is None: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -194,12 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None: with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) - import pdb - - pdb.set_trace() - - self.trainer._results.batch_size = batch_size - self.batch_progress.increment_processed() # update non-plateau LR schedulers diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 37fcb06a1dc24..7471b5c823dac 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -253,14 +253,6 @@ def _log_gpus_metrics(self) -> None: Utilities and properties """ - def on_new_batch(self, batch: Any) -> int: - # when the user requests `dataloader_iter`, we can't track the batch_size - # and this is left to user responsibility. 
- if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter): - assert self.trainer._results is not None - return self.trainer._results.extract_batch_size(batch) - return 1 - def on_epoch_start(self) -> None: self._epoch_end_reached = False @@ -274,7 +266,6 @@ def epoch_end_reached(self) -> None: self._batch_idx = None self._split_idx = None assert self.trainer._results is not None - self.trainer._results.batch_size = 1 def on_epoch_end(self) -> None: assert self._epoch_end_reached diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index bba449235b7f7..67246281c4499 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -23,7 +23,6 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device -from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -356,7 +355,6 @@ class ResultCollection(dict): def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None: super().__init__() self.training = training - self._batch_size = torch.tensor(1, device=device) self._current_batch = None self.device: Optional[Union[str, torch.device]] = device @@ -394,7 +392,7 @@ def log( sync_dist_fn: Callable = _Sync.no_op, sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, - batch_size: Optional[int] = None, + batch_size: int = 1, metric_attribute: Optional[str] = None, rank_zero_only: bool = False, ) -> None: @@ -438,11 +436,8 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed" ) - if batch_size is None and on_epoch: - batch_size = extract_batch_size(self._current_batch) - - self.batch_size = batch_size - self.update_metrics(key, value) + batch_size = torch.tensor(batch_size, device=torch.device) + self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: """Create one ResultMetric object per value. 
@@ -459,10 +454,10 @@ def fn(v: _IN_METRIC) -> ResultMetric: value = ResultMetricCollection(value) self[key] = value - def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: + def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: torch.Tensor) -> None: def fn(result_metric: ResultMetric, v: ResultMetric) -> None: # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` - result_metric.forward(v.to(self.device), self.batch_size) + result_metric.forward(v.to(self.device), batch_size) result_metric.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @@ -560,7 +555,6 @@ def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) - self._batch_size = self._batch_size.to(*args, **kwargs) if "device" in kwargs: self.device = kwargs["device"] return self From 51a7d36aa69d6459d6205878937b02c504bf980f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 18:31:18 +0530 Subject: [PATCH 03/34] fix --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 +- .../trainer/connectors/logger_connector/logger_connector.py | 3 +-- .../trainer/connectors/logger_connector/result.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 3794ac40b7fc0..8ddca3ad505e8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -161,7 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() - self.trainer._results.current_batch = batch + self.trainer.logger_connector.on_batch_start(batch_idx, batch) if batch is None: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7471b5c823dac..47b0920ad62a9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -210,7 +210,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]: def on_train_split_start(self, split_idx: int, split_batch: Any) -> None: self._split_idx = split_idx - self.on_new_batch(split_batch) def update_train_step_metrics(self) -> None: if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization: @@ -259,7 +258,7 @@ def on_epoch_start(self) -> None: def on_batch_start(self, batch_idx: int, batch: Any) -> int: self._batch_idx = batch_idx self._epoch_end_reached = False - return self.on_new_batch(batch) + self.trainer._results.current_batch = batch def epoch_end_reached(self) -> None: self._epoch_end_reached = True diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 67246281c4499..eaf4bfd904f86 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -436,7 +436,7 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed" ) - batch_size = torch.tensor(batch_size, device=torch.device) + batch_size = torch.tensor(batch_size, device=self.device) self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: From c2a3626946b4b85a69ca3c614560caeaada964b6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 18:48:55 +0530 Subject: [PATCH 04/34] fix --- pytorch_lightning/core/lightning.py | 8 ++++---- tests/trainer/logging_/test_train_loop_logging.py | 6 +----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ddc8cfb7edbe7..c621b61f01cc9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -454,6 +454,10 @@ def log_dict( raise MisconfigurationException( "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" ) + + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) if isinstance(reduce_fx, str): @@ -481,10 +485,6 @@ def log_dict( elif reduce_fx == "default": reduce_fx = "mean" - # set the default depending on the fx_name - on_step = self.__auto_choose_log_on_step(on_step) - on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): # if we started a new epoch (running it's first batch) the hook name has changed # reset any tensors for the new hook name diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 6cad94017177e..c42c5a0f06bab 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -616,7 +616,7 @@ def training_step(self, batch, batch_idx): trainer = Trainer(default_root_dir=tmpdir) model = TestModel() - with pytest.raises(MisconfigurationException, match="`self.log` with the key `foo/dataloader_idx_0`"): + with pytest.raises(MisconfigurationException, match="`self.log` with the key `'foo/dataloader_idx_0'`"): trainer.fit(model) class TestModel(BoringModel): @@ -717,19 +717,15 @@ def on_validation_epoch_end(self): assert all(v == 3 for v in self.trainer.callback_metrics.values()) def on_train_batch_start(self, batch, batch_idx): - assert self.trainer._results.batch_size == 2 self.log("on_train_batch_start", 1.0, reduce_fx="sum") def on_train_batch_end(self, outputs, batch, batch_idx): - assert self.trainer._results.batch_size == 2 self.log("on_train_batch_end", 1.0, reduce_fx="sum") def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - assert self.trainer._results.batch_size == 2 self.log("on_validation_batch_start", 1.0, reduce_fx="sum") def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - assert self.trainer._results.batch_size == 2 self.log("on_validation_batch_end", 1.0, reduce_fx="sum") def training_epoch_end(self, *_) -> None: From be5882c76e2736b832318bb9e14186cbbe331038 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 19:48:58 +0530 Subject: [PATCH 05/34] update tests --- .../trainer/connectors/logger_connector/result.py | 1 + tests/loops/batch/test_truncated_bptt.py | 1 - tests/loops/test_loop_state_dict.py | 5 ----- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index eaf4bfd904f86..380ade3f74c06 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -583,6 +583,7 @@ def __repr__(self) -> str: def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() + del d["_current_batch"] # all the items should be either `ResultMetric`s or `ResultMetricCollection`s items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()} return {**d, "items": items} diff --git a/tests/loops/batch/test_truncated_bptt.py b/tests/loops/batch/test_truncated_bptt.py index e27f21d5450e8..55adbc618b9f9 100644 --- a/tests/loops/batch/test_truncated_bptt.py +++ b/tests/loops/batch/test_truncated_bptt.py @@ -37,7 +37,6 @@ def configure_optimizers(self): def training_step(self, batch, batch_idx, hiddens): x, y = batch - breakpoint() pred, hiddens = self.lstm(x, hiddens) loss = F.mse_loss(pred, y) return {"loss": loss, "hiddens": hiddens} diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 717d625f6c44e..acd00e3130c11 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -14,7 +14,6 @@ from unittest.mock import Mock import pytest -import torch from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer @@ -81,13 +80,11 @@ def test_loops_state_dict_structure(): }, "epoch_loop.val_loop._results": { "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, "epoch_loop._results": { "training": True, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, @@ -107,7 +104,6 @@ def test_loops_state_dict_structure(): }, "_results": { "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, @@ -123,7 +119,6 @@ def test_loops_state_dict_structure(): }, "_results": { "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, From bba8055d0193b755a3b273a931849cdac865c845 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 8 Nov 2021 19:54:11 +0530 Subject: [PATCH 06/34] chlog --- CHANGELOG.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac4b5e2468a69..f63e1f5876197 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,14 +106,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388)) -- - -- +- Fixed `extract_batch_size` call only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) -- - ## [1.5.0] - 2021-11-02 From 308f9c0a8bb7946e94538a7b146e719ffe46f995 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 16:37:56 +0530 Subject: [PATCH 07/34] fix --- CHANGELOG.md | 2 +- pytorch_lightning/core/lightning.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f63e1f5876197..cba3b7a4f7854 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -107,7 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388)) -- Fixed `extract_batch_size` call only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) +- Fixed the call to `extract_batch_size` only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c621b61f01cc9..e4c2613c000eb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1651,7 +1651,7 @@ def optimizer_step( pg["lr"] = lr_scale * self.learning_rate # update params - optimizer.step(clo, batch_sizesure=optimizer_closure) + optimizer.step(closure=optimizer_closure) """ optimizer.step(closure=optimizer_closure) From 04314b82518496dfd5a2317059a3e5936a05190e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 18:29:06 +0530 Subject: [PATCH 08/34] add test --- .../logging_/test_train_loop_logging.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index c42c5a0f06bab..34c94e3ba63ea 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -747,3 +747,38 @@ def validation_epoch_end(self, *_) -> None: train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + + +def test_no_batch_size_extraction_with_specifying_explictly(tmpdir): + batch_size = BoringModel().train_dataloader().batch_size + 10 + fast_dev_run = 2 + log_val = 7.0 + + class CustomBoringModel(BoringModel): + def on_before_batch_transfer(self, batch, *args, **kwargs): + # This is an ambiguous batch which have multiple potential batch sizes + if self.trainer.training: + batch = {"batch1": torch.randn(batch_size + 10, 10), "batch2": batch} + return batch + + def training_step(self, batch, batch_idx): + self.log("step_log_val", log_val, on_epoch=False) + self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True) + self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") + return super().training_step(batch["batch2"], batch_idx) + + def training_epoch_end(self, *args, **kwargs): + results = self.trainer._results + assert results["training_step.step_log_val"].value == log_val + assert results["training_step.step_log_val"].cumulated_batch_size == 0 + assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run + assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run + assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) + + with pytest.warns(None) as record: + trainer.fit(model) + + assert not any("Trying to infer the `batch_size`" in warn.message.args[0] for warn in record.list) From 46a8861da7b79fdeb787b6b42018724b803de6f1 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 18:34:19 +0530 Subject: [PATCH 09/34] mypy --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 47b0920ad62a9..1ac2543542099 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -255,7 +255,7 @@ def _log_gpus_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self, batch_idx: int, batch: Any) -> int: + def on_batch_start(self, batch_idx: int, batch: Any) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False self.trainer._results.current_batch = batch From 9b28e44c4823868d17a46a18dbfa40d20ce81bbc Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 18:48:48 +0530 Subject: [PATCH 10/34] mypy --- .../trainer/connectors/logger_connector/logger_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 1ac2543542099..5a346667f46bc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -258,6 +258,7 @@ def on_epoch_start(self) -> None: def on_batch_start(self, batch_idx: int, batch: Any) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False + assert self.trainer._results is not None self.trainer._results.current_batch = batch def epoch_end_reached(self) -> None: From 077296725b8ebec55bd1e05d206b24988a40a507 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 19:36:53 +0530 Subject: [PATCH 11/34] deref current batch --- .../trainer/connectors/logger_connector/logger_connector.py | 1 + .../trainer/connectors/logger_connector/result.py | 1 - tests/loops/test_loop_state_dict.py | 4 ++++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 5a346667f46bc..0388518415257 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -281,6 +281,7 @@ def on_batch_end(self) -> None: self._progress_bar_metrics.update(metrics["pbar"]) self._callback_metrics.update(metrics["callback"]) self._logged_metrics.update(metrics["log"]) + self.trainer._results.current_batch = None def should_reset_tensors(self, fx: str) -> bool: is_different_fx = self._current_fx != fx diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 380ade3f74c06..eaf4bfd904f86 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -583,7 +583,6 @@ def __repr__(self) -> str: def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() - del d["_current_batch"] # all the items should be either `ResultMetric`s or `ResultMetricCollection`s items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()} return {**d, "items": items} diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index acd00e3130c11..a4f9e1824a500 100644 --- a/tests/loops/test_loop_state_dict.py +++ 
b/tests/loops/test_loop_state_dict.py @@ -79,11 +79,13 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "epoch_loop.val_loop._results": { + "_current_batch": None, "training": False, "device": None, "items": {}, }, "epoch_loop._results": { + "_current_batch": None, "training": True, "device": None, "items": {}, @@ -103,6 +105,7 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { + "_current_batch": None, "training": False, "device": None, "items": {}, @@ -118,6 +121,7 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { + "_current_batch": None, "training": False, "device": None, "items": {}, From c3567ffcdbec4c2ee987713a51d5361d9d4de8e5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 19:40:59 +0530 Subject: [PATCH 12/34] mypy --- .../trainer/connectors/logger_connector/logger_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0388518415257..a31b0ca01c42a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -281,6 +281,7 @@ def on_batch_end(self) -> None: self._progress_bar_metrics.update(metrics["pbar"]) self._callback_metrics.update(metrics["callback"]) self._logged_metrics.update(metrics["log"]) + assert self.trainer._results is not None self.trainer._results.current_batch = None def should_reset_tensors(self, fx: str) -> bool: From 1484284fbd9f0011452d67ffb11382700fd6bd6e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 9 Nov 2021 21:05:45 +0530 Subject: [PATCH 13/34] rev --- CHANGELOG.md | 11 +- pytorch_lightning/core/lightning.py | 217 +++++++++--------- .../logging_/test_train_loop_logging.py | 2 +- 3 files changed, 116 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19f14f358c447..ff0e9cd9293ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -116,15 +116,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409)) -- Fixed the call to `extract_batch_size` only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) - - - Fixed dataloader workers with `persistent_workers` being deleted on every iteration ([#10434](https://github.com/PyTorchLightning/pytorch-lightning/pull/10434)) - Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429)) +- Fixed the call to `extract_batch_size` only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) + + +- + + +- + ## [1.5.0] - 2021-11-02 diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 116f08438b02a..aae45fba0da9d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -356,25 +356,117 @@ def log( rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. 
""" - self.log_dict( - dictionary={name: value}, + # check for invalid values + apply_to_collection(value, dict, self.__check_not_nested, name) + apply_to_collection( + value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict) + ) + + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + + if self.trainer is None: + # not an error to support testing the `*_step` methods without a `Trainer` reference + rank_zero_warn( + "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." + " This is most likely because the model hasn't been passed to the `Trainer`" + ) + return + results = self.trainer._results + if results is None: + raise MisconfigurationException( + "You are trying to `self.log()` but the loop `ResultCollection` is not registered" + " yet. This is most likely because you are trying to log in a `predict` hook," + " but it doesn't support logging" + ) + if self._current_fx_name is None: + raise MisconfigurationException( + "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" + ) + _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"You called `self.log` with the key `{name}`" + " but it should not contain information about `dataloader_idx`" + ) + + value = apply_to_collection(value, numbers.Number, self.__to_tensor) + + if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): + # if we started a new epoch (running it's first batch) the hook name has changed + # reset any tensors for the new hook name + results.reset(metrics=False, fx=self._current_fx_name) + + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name for name, module in self.named_modules() if isinstance(module, Metric) + } + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(id(value), None) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" + f" of {list(self._metric_attributes.values())}" + ) + + if ( + self.trainer.training + and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) + and batch_size is None + ): + raise MisconfigurationException( + "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." 
+ ) + + if isinstance(reduce_fx, str): + reduce_fx = reduce_fx.lower() + + # extract batch size if it is None and whenever it is required + if batch_size is None: + if ( + on_epoch + and True in _FxValidator.functions[self._current_fx_name]["on_step"] + and reduce_fx in ("mean", "avg") + ): + try: + batch_size = extract_batch_size(self.trainer._results.current_batch) + except RecursionError: + batch_size = 1 + else: + batch_size = 1 + + results.log( + self._current_fx_name, + name, + value, prog_bar=prog_bar, logger=logger, on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx, - add_dataloader_idx=add_dataloader_idx, + dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, + sync_dist=sync_dist and distributed_available(), + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, ) + self.trainer.logger_connector._current_fx = self._current_fx_name + def log_dict( self, dictionary: Mapping[str, _METRIC_COLLECTION], @@ -388,7 +480,6 @@ def log_dict( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, rank_zero_only: Optional[bool] = None, ) -> None: """Log a dictionary of values at once. @@ -415,120 +506,26 @@ def log_dict( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. - metric_attribute: To restore the metric state, Lightning requires the reference of the - :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute. rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. """ - if self.trainer is None: - # not an error to support testing the `*_step` methods without a `Trainer` reference - rank_zero_warn( - "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." - " This is most likely because the model hasn't been passed to the `Trainer`" - ) - return - - if ( - self.trainer.training - and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) - and batch_size is None - ): - raise MisconfigurationException( - "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." - ) - - results = self.trainer._results - if results is None: - raise MisconfigurationException( - "You are trying to `self.log()` but the loop `ResultCollection` is not registered" - " yet. 
This is most likely because you are trying to log in a `predict` hook," - " but it doesn't support logging" - ) - if self._current_fx_name is None: - raise MisconfigurationException( - "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" - ) - - # set the default depending on the fx_name - on_step = self.__auto_choose_log_on_step(on_step) - on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) - - if isinstance(reduce_fx, str): - reduce_fx = reduce_fx.lower() - - if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name): - # if we started a new epoch (running it's first batch) the hook name has changed - # reset any tensors for the new hook name - results.reset(metrics=False, fx=self._current_fx_name) - - # extract batch size if it's None whenever it's required - if batch_size is None: - if ( - on_epoch - and True in _FxValidator.functions[self._current_fx_name]["on_step"] - and reduce_fx in ("mean", "avg") - ): - batch_size = extract_batch_size(self.trainer._results.current_batch) - else: - batch_size = 1 - - for name, value in dictionary.items(): - # make sure user doesn't introduce logic for multi-dataloaders - if "/dataloader_idx_" in name: - raise MisconfigurationException( - f"You called `self.log` with the key `{name!r}`" - " but it should not contain information about `dataloader_idx`" - ) - - # check for invalid values - apply_to_collection(value, dict, self.__check_not_nested, name) - apply_to_collection( - value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict) - ) - value = apply_to_collection(value, numbers.Number, self.__to_tensor) - - if metric_attribute is None and isinstance(value, Metric): - if self._metric_attributes is None: - # compute once - self._metric_attributes = { - id(module): name for name, module in self.named_modules() if isinstance(module, Metric) - } - if not self._metric_attributes: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." - " You can fix this by setting an attribute for the metric in your `LightningModule`." - ) - # try to find the passed metric in the LightningModule - metric_attribute = self._metric_attributes.get(id(value), None) - if metric_attribute is None: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." - f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name`" - f" is one of {list(self._metric_attributes.values())}." - ) - - results.log( - self._current_fx_name, - name, - value, + for k, v in dictionary.items(): + self.log( + name=k, + value=v, prog_bar=prog_bar, logger=logger, on_step=on_step, on_epoch=on_epoch, reduce_fx=reduce_fx, enable_graph=enable_graph, - dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), - batch_size=batch_size, - sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist=sync_dist, sync_dist_group=sync_dist_group, - metric_attribute=metric_attribute, + add_dataloader_idx=add_dataloader_idx, + batch_size=batch_size, rank_zero_only=rank_zero_only, ) - self.trainer.logger_connector._current_fx = self._current_fx_name - @staticmethod def __check_not_nested(value: dict, name: str) -> dict: # self-imposed restriction. 
for simplicity diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 8505dcdce7369..1882695f2a332 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -614,7 +614,7 @@ def training_step(self, batch, batch_idx): trainer = Trainer(default_root_dir=tmpdir) model = TestModel() - with pytest.raises(MisconfigurationException, match="`self.log` with the key `'foo/dataloader_idx_0'`"): + with pytest.raises(MisconfigurationException, match="`self.log` with the key `foo/dataloader_idx_0`"): trainer.fit(model) class TestModel(BoringModel): From 4363e6920138c1c315b91266728a0c191f83bf7a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Nov 2021 19:34:48 +0530 Subject: [PATCH 14/34] cache batch size --- pytorch_lightning/core/lightning.py | 10 ++++++++-- .../connectors/logger_connector/logger_connector.py | 1 + .../trainer/connectors/logger_connector/result.py | 9 +++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index aae45fba0da9d..5babbc7094255 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -429,8 +429,11 @@ def log( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." ) - if isinstance(reduce_fx, str): - reduce_fx = reduce_fx.lower() + reduce_fx = reduce_fx.lower() if isinstance(reduce_fx, str) else reduce_fx + + # check if we have extracted the batch size already + if batch_size is None: + batch_size = batch_size or results.current_batch_size # extract batch size if it is None and whenever it is required if batch_size is None: @@ -443,6 +446,9 @@ def log( batch_size = extract_batch_size(self.trainer._results.current_batch) except RecursionError: batch_size = 1 + + # cache batch_size + results.current_batch_size = batch_size else: batch_size = 1 diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a31b0ca01c42a..43a55179e952f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -283,6 +283,7 @@ def on_batch_end(self) -> None: self._logged_metrics.update(metrics["log"]) assert self.trainer._results is not None self.trainer._results.current_batch = None + self.trainer._results.current_batch_size = None def should_reset_tensors(self, fx: str) -> bool: is_different_fx = self._current_fx != fx diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index eaf4bfd904f86..9f9905c1a1523 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -356,6 +356,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = super().__init__() self.training = training self._current_batch = None + self._current_batch_size = None self.device: Optional[Union[str, torch.device]] = device @property @@ -369,6 +370,14 @@ def append_fn(v: ResultMetric) -> None: apply_to_collection(list(self.values()), ResultMetric, append_fn) return o + @property + def current_batch_size(self) -> Optional[int]: + return self._current_batch_size + + @current_batch_size.setter + def 
current_batch_size(self, val: int) -> None: + self._current_batch_size = val + @property def current_batch(self) -> Any: return self._current_batch From cb9faae05068465bf554dcc6a9003e7d9f110850 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 00:11:36 +0530 Subject: [PATCH 15/34] update test --- tests/loops/test_loop_state_dict.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index a4f9e1824a500..4d1bd57f66ad6 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -80,12 +80,14 @@ def test_loops_state_dict_structure(): }, "epoch_loop.val_loop._results": { "_current_batch": None, + "_current_batch_size": None, "training": False, "device": None, "items": {}, }, "epoch_loop._results": { "_current_batch": None, + "_current_batch_size": None, "training": True, "device": None, "items": {}, @@ -106,6 +108,7 @@ def test_loops_state_dict_structure(): }, "_results": { "_current_batch": None, + "_current_batch_size": None, "training": False, "device": None, "items": {}, @@ -122,6 +125,7 @@ def test_loops_state_dict_structure(): }, "_results": { "_current_batch": None, + "_current_batch_size": None, "training": False, "device": None, "items": {}, From 61b948395db4531c0ab61dad41e6567f56984081 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 00:15:27 +0530 Subject: [PATCH 16/34] mypy --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index dac4b25c46d17..233b8e46fa797 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -395,7 +395,7 @@ def current_batch_size(self) -> Optional[int]: return self._current_batch_size @current_batch_size.setter - def current_batch_size(self, val: int) -> None: + def current_batch_size(self, val: Optional[int]) -> None: self._current_batch_size = val @property From 35c37ac431fff81afd9382632ce2e41c2575bf43 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 00:21:25 +0530 Subject: [PATCH 17/34] mypy --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 233b8e46fa797..b510fd1035fcf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -376,7 +376,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = super().__init__() self.training = training self._current_batch = None - self._current_batch_size = None + self._current_batch_size: Optional[int] = None self.device: Optional[Union[str, torch.device]] = device @property From 46c22be2e91519b4d6879c92511ecd1c6ac1e822 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 00:57:14 +0530 Subject: [PATCH 18/34] move to resultcollection --- pytorch_lightning/core/lightning.py | 24 ------------------- .../logger_connector/logger_connector.py | 1 + .../connectors/logger_connector/result.py | 22 ++++++++++++++++- 3 files changed, 22 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/core/lightning.py 
b/pytorch_lightning/core/lightning.py index 5babbc7094255..7867211badb35 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -46,7 +46,6 @@ ) from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import get_model_size_mb @@ -429,29 +428,6 @@ def log( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." ) - reduce_fx = reduce_fx.lower() if isinstance(reduce_fx, str) else reduce_fx - - # check if we have extracted the batch size already - if batch_size is None: - batch_size = batch_size or results.current_batch_size - - # extract batch size if it is None and whenever it is required - if batch_size is None: - if ( - on_epoch - and True in _FxValidator.functions[self._current_fx_name]["on_step"] - and reduce_fx in ("mean", "avg") - ): - try: - batch_size = extract_batch_size(self.trainer._results.current_batch) - except RecursionError: - batch_size = 1 - - # cache batch_size - results.current_batch_size = batch_size - else: - batch_size = 1 - results.log( self._current_fx_name, name, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 43a55179e952f..7f991fb85ce73 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -260,6 +260,7 @@ def on_batch_start(self, batch_idx: int, batch: Any) -> None: self._epoch_end_reached = False assert self.trainer._results is not None self.trainer._results.current_batch = batch + self.trainer._results.current_batch_size = None def epoch_end_reached(self) -> None: self._epoch_end_reached = True diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index b510fd1035fcf..60bd2c20966fc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -21,8 +21,10 @@ from typing_extensions import TypedDict from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device +from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -421,7 +423,7 @@ def log( sync_dist_fn: Callable = _Sync.no_op, sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, - batch_size: int = 1, + batch_size: Optional[int] = None, metric_attribute: Optional[str] = None, rank_zero_only: bool = False, ) -> None: @@ -465,7 +467,25 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed" ) + # check if we have extracted the batch size already + if batch_size is None: + batch_size = batch_size or self.current_batch_size + + # extract batch size if it is None and whenever it is required + if batch_size is None: + if on_epoch and (True in _FxValidator.functions[fx.split(".")[0]]["on_step"]) and meta.is_mean_reduction: + try: + batch_size = extract_batch_size(self.current_batch) + except RecursionError: + batch_size = 1 + + # cache batch_size + self.current_batch_size = batch_size + else: + batch_size = 1 + batch_size = torch.tensor(batch_size, device=self.device) + self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: From 9b25ecfe4be281e0a36e6acce18e1a688deba59b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 01:23:31 +0530 Subject: [PATCH 19/34] mypy --- .../trainer/connectors/logger_connector/result.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 60bd2c20966fc..c3757cef740a0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -473,7 +473,9 @@ def log( # extract batch size if it is None and whenever it is required if batch_size is None: - if on_epoch and (True in _FxValidator.functions[fx.split(".")[0]]["on_step"]) and meta.is_mean_reduction: + fx_validate = _FxValidator.functions[fx.split(".")[0]] + assert fx_validate is not None + if on_epoch and (True in fx_validate["on_step"]) and meta.is_mean_reduction: try: batch_size = extract_batch_size(self.current_batch) except RecursionError: From b7a2296de3c920afbc038caad530c4dbe2f4b17c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 11 Nov 2021 01:47:04 +0530 Subject: [PATCH 20/34] update logic --- .../trainer/connectors/logger_connector/result.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c3757cef740a0..722d2438964dd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -473,9 +473,8 @@ def log( # extract batch size if it is None and whenever it is required if batch_size is None: - fx_validate = _FxValidator.functions[fx.split(".")[0]] - assert fx_validate is not None - if on_epoch and (True in fx_validate["on_step"]) and meta.is_mean_reduction: + fx_validate = _FxValidator.functions.get(fx.split(".")[0]) + if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: try: batch_size = extract_batch_size(self.current_batch) except RecursionError: From 55a189e0b13091626dfc6d83fd734eb87d76be2b Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 18 Nov 2021 13:04:08 +0530 Subject: [PATCH 21/34] Apply suggestions from code review --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e112dd1d8bd5..ae0bc5043cb3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) -- Fixed the call to `extract_batch_size` only when its required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) +- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) - From c9a854337c3a629181e0fc466c47d427e20fdc07 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 18 Nov 2021 13:06:36 +0530 Subject: [PATCH 22/34] chlog --- CHANGELOG.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0ae26ef06238..7780008c6a60f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -142,10 +142,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - -- - - - Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) From a053dd169b8f24b33963b4266c75bbc0d4686a9d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 18 Nov 2021 15:53:39 +0530 Subject: [PATCH 23/34] update on comments --- .../logger_connector/logger_connector.py | 4 ++ .../connectors/logger_connector/result.py | 48 ++++++++++--------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7f991fb85ce73..2cd1a51ad043e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -259,6 +259,8 @@ def on_batch_start(self, batch_idx: int, batch: Any) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False assert self.trainer._results is not None + + # attach reference to the new batch and remove the cached batch_size self.trainer._results.current_batch = batch self.trainer._results.current_batch_size = None @@ -283,6 +285,8 @@ def on_batch_end(self) -> None: self._callback_metrics.update(metrics["callback"]) self._logged_metrics.update(metrics["log"]) assert self.trainer._results is not None + + # drop the reference to current batch and batch_size self.trainer._results.current_batch = None self.trainer._results.current_batch_size = None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 722d2438964dd..f7645aab4c64c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -212,7 +212,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum) - def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None: + def update(self, value: _IN_METRIC, batch_size: int) -> None: if self.is_tensor: value = value.float() if self.meta.on_step: @@ -251,7 +251,7 @@ def reset(self) -> None: self.value.reset() self.has_reset = True - def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None: + def forward(self, value: _IN_METRIC, batch_size: int) -> None: if self.meta.enable_graph: with 
torch.no_grad(): self.update(value, batch_size) @@ -408,6 +408,27 @@ def current_batch(self) -> Any: def current_batch(self, data: Any) -> None: self._current_batch = data + def _extract_batch_size(self, batch_size: Optional[int], on_epoch: bool, fx: str, meta: _Metadata) -> int: + # check if we have extracted the batch size already + if batch_size is None: + batch_size = batch_size or self.current_batch_size + + # extract batch size if it is None and whenever it is required + if batch_size is None: + fx_validate = _FxValidator.functions.get(fx.split(".")[0]) + if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: + try: + batch_size = extract_batch_size(self.current_batch) + except RecursionError: + batch_size = 1 + + # cache batch_size + self.current_batch_size = batch_size + else: + batch_size = 1 + + return batch_size + def log( self, fx: str, @@ -467,26 +488,7 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed" ) - # check if we have extracted the batch size already - if batch_size is None: - batch_size = batch_size or self.current_batch_size - - # extract batch size if it is None and whenever it is required - if batch_size is None: - fx_validate = _FxValidator.functions.get(fx.split(".")[0]) - if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: - try: - batch_size = extract_batch_size(self.current_batch) - except RecursionError: - batch_size = 1 - - # cache batch_size - self.current_batch_size = batch_size - else: - batch_size = 1 - - batch_size = torch.tensor(batch_size, device=self.device) - + batch_size = self._extract_batch_size(batch_size, on_epoch, fx, meta) self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: @@ -504,7 +506,7 @@ def fn(v: _IN_METRIC) -> ResultMetric: value = ResultMetricCollection(value) self[key] = value - def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: torch.Tensor) -> None: + def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None: def fn(result_metric: ResultMetric, v: ResultMetric) -> None: # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` result_metric.forward(v.to(self.device), batch_size) From a61dad1f6dffb8ad8dd195e0284c16b2ae455ccf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:41:24 +0100 Subject: [PATCH 24/34] Use our utilities --- tests/deprecated_api/__init__.py | 28 ++++++++++++++----- .../logging_/test_train_loop_logging.py | 13 ++++----- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index 1026981f75307..91c7ef1c1f880 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -14,7 +14,7 @@ """Test deprecated functionality which will be removed in vX.Y.Z.""" import sys from contextlib import contextmanager -from typing import Optional +from typing import Optional, Type import pytest @@ -26,14 +26,28 @@ def _soft_unimport_module(str_module): @contextmanager -def no_deprecated_call(match: Optional[str] = None): +def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None): with pytest.warns(None) as record: yield + + if match is None: try: - w = record.pop(DeprecationWarning) - if match is not None and match not in str(w.message): - return + 
w = record.pop(expected_warning) except AssertionError: - # no DeprecationWarning raised + # no warning raised + return + else: + for w in record.list: + if w.category is expected_warning and match in w.message.args[0]: + break + else: return - raise AssertionError(f"`DeprecationWarning` was raised: {w}") + + msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`" + raise AssertionError(f"{msg} was raised: {w}") + + +@contextmanager +def no_deprecated_call(match: Optional[str] = None): + with no_warning_call(expected_warning=DeprecationWarning, match=match): + yield diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 3227a5f838f77..0ec61358d9408 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -27,6 +27,7 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset from tests.helpers.runif import RunIf @@ -748,15 +749,15 @@ def validation_epoch_end(self, *_) -> None: def test_no_batch_size_extraction_with_specifying_explictly(tmpdir): - batch_size = BoringModel().train_dataloader().batch_size + 10 + batch_size = BoringModel().train_dataloader().batch_size + 1 fast_dev_run = 2 - log_val = 7.0 + log_val = 7 class CustomBoringModel(BoringModel): def on_before_batch_transfer(self, batch, *args, **kwargs): # This is an ambiguous batch which have multiple potential batch sizes if self.trainer.training: - batch = {"batch1": torch.randn(batch_size + 10, 10), "batch2": batch} + batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch} return batch def training_step(self, batch, batch_idx): @@ -765,7 +766,7 @@ def training_step(self, batch, batch_idx): self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") return super().training_step(batch["batch2"], batch_idx) - def training_epoch_end(self, *args, **kwargs): + def on_train_epoch_end(self, *args, **kwargs): results = self.trainer._results assert results["training_step.step_log_val"].value == log_val assert results["training_step.step_log_val"].cumulated_batch_size == 0 @@ -776,7 +777,5 @@ def training_epoch_end(self, *args, **kwargs): model = CustomBoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) - with pytest.warns(None) as record: + with no_warning_call(match="Trying to infer the `batch_size`"): trainer.fit(model) - - assert not any("Trying to infer the `batch_size`" in warn.message.args[0] for warn in record.list) From b9d2a56fbb65ecf9c4fd6a1fcca482ddd9d6b3c3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:46:14 +0100 Subject: [PATCH 25/34] Whitespace --- .../trainer/connectors/logger_connector/logger_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 2cd1a51ad043e..0a7db697b431f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -258,8 +258,8 @@ def on_epoch_start(self) -> None: def 
on_batch_start(self, batch_idx: int, batch: Any) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False - assert self.trainer._results is not None + assert self.trainer._results is not None # attach reference to the new batch and remove the cached batch_size self.trainer._results.current_batch = batch self.trainer._results.current_batch_size = None @@ -284,8 +284,8 @@ def on_batch_end(self) -> None: self._progress_bar_metrics.update(metrics["pbar"]) self._callback_metrics.update(metrics["callback"]) self._logged_metrics.update(metrics["log"]) - assert self.trainer._results is not None + assert self.trainer._results is not None # drop the reference to current batch and batch_size self.trainer._results.current_batch = None self.trainer._results.current_batch_size = None From f87d2159e15391b3c5a1a7a5835388cbc15f696a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:49:21 +0100 Subject: [PATCH 26/34] Remove unnecesary properties --- .../connectors/logger_connector/result.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f7645aab4c64c..a04bec1ee30a2 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -377,9 +377,9 @@ class ResultCollection(dict): def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None: super().__init__() self.training = training - self._current_batch = None - self._current_batch_size: Optional[int] = None self.device: Optional[Union[str, torch.device]] = device + self.current_batch: Optional[Any] = None + self.current_batch_size: Optional[int] = None @property def result_metrics(self) -> List[ResultMetric]: @@ -392,22 +392,6 @@ def append_fn(v: ResultMetric) -> None: apply_to_collection(list(self.values()), ResultMetric, append_fn) return o - @property - def current_batch_size(self) -> Optional[int]: - return self._current_batch_size - - @current_batch_size.setter - def current_batch_size(self, val: Optional[int]) -> None: - self._current_batch_size = val - - @property - def current_batch(self) -> Any: - return self._current_batch - - @current_batch.setter - def current_batch(self, data: Any) -> None: - self._current_batch = data - def _extract_batch_size(self, batch_size: Optional[int], on_epoch: bool, fx: str, meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: From 0074b8149645757e69a1a95fe92b11fe0dc2a025 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:50:03 +0100 Subject: [PATCH 27/34] Remove current prefix --- .../connectors/logger_connector/logger_connector.py | 8 ++++---- .../trainer/connectors/logger_connector/result.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0a7db697b431f..4b56aefb9809f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -261,8 +261,8 @@ def on_batch_start(self, batch_idx: int, batch: Any) -> None: assert self.trainer._results is not None # attach reference to the new batch and remove the cached batch_size - 
self.trainer._results.current_batch = batch - self.trainer._results.current_batch_size = None + self.trainer._results.batch = batch + self.trainer._results.batch_size = None def epoch_end_reached(self) -> None: self._epoch_end_reached = True @@ -287,8 +287,8 @@ def on_batch_end(self) -> None: assert self.trainer._results is not None # drop the reference to current batch and batch_size - self.trainer._results.current_batch = None - self.trainer._results.current_batch_size = None + self.trainer._results.batch = None + self.trainer._results.batch_size = None def should_reset_tensors(self, fx: str) -> bool: is_different_fx = self._current_fx != fx diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a04bec1ee30a2..06f749454c31e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -378,8 +378,8 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = super().__init__() self.training = training self.device: Optional[Union[str, torch.device]] = device - self.current_batch: Optional[Any] = None - self.current_batch_size: Optional[int] = None + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None @property def result_metrics(self) -> List[ResultMetric]: @@ -395,19 +395,19 @@ def append_fn(v: ResultMetric) -> None: def _extract_batch_size(self, batch_size: Optional[int], on_epoch: bool, fx: str, meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: - batch_size = batch_size or self.current_batch_size + batch_size = batch_size or self.batch_size # extract batch size if it is None and whenever it is required if batch_size is None: fx_validate = _FxValidator.functions.get(fx.split(".")[0]) if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: try: - batch_size = extract_batch_size(self.current_batch) + batch_size = extract_batch_size(self.batch) except RecursionError: batch_size = 1 # cache batch_size - self.current_batch_size = batch_size + self.batch_size = batch_size else: batch_size = 1 From cb9de15e9e35d05b6570b0cd0aaca7b0ae4c3ab8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:52:10 +0100 Subject: [PATCH 28/34] Simplify arguments --- .../trainer/connectors/logger_connector/result.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 06f749454c31e..bb75991cdb308 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -392,15 +392,20 @@ def append_fn(v: ResultMetric) -> None: apply_to_collection(list(self.values()), ResultMetric, append_fn) return o - def _extract_batch_size(self, batch_size: Optional[int], on_epoch: bool, fx: str, meta: _Metadata) -> int: + def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: batch_size = batch_size or self.batch_size # extract batch size if it is None and whenever it is required if batch_size is None: - fx_validate = _FxValidator.functions.get(fx.split(".")[0]) - if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and 
meta.is_mean_reduction: + fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) + if ( + meta.on_epoch + and fx_validate is not None + and (True in fx_validate["on_step"]) + and meta.is_mean_reduction + ): try: batch_size = extract_batch_size(self.batch) except RecursionError: @@ -472,7 +477,7 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed" ) - batch_size = self._extract_batch_size(batch_size, on_epoch, fx, meta) + batch_size = self._extract_batch_size(batch_size, meta) self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: From 92cd293817472d90823c6e1c0fe4ef3910c18da7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 05:57:40 +0100 Subject: [PATCH 29/34] Avoid indentations --- .../connectors/logger_connector/result.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index bb75991cdb308..ac843c7703865 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -394,28 +394,22 @@ def append_fn(v: ResultMetric) -> None: def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int: # check if we have extracted the batch size already - if batch_size is None: - batch_size = batch_size or self.batch_size - - # extract batch size if it is None and whenever it is required - if batch_size is None: - fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) - if ( - meta.on_epoch - and fx_validate is not None - and (True in fx_validate["on_step"]) - and meta.is_mean_reduction - ): - try: - batch_size = extract_batch_size(self.batch) - except RecursionError: - batch_size = 1 - - # cache batch_size - self.batch_size = batch_size - else: + batch_size = batch_size or self.batch_size + if batch_size is not None: + return batch_size + + # extract it + fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) + if meta.on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: + try: + batch_size = extract_batch_size(self.batch) + except RecursionError: batch_size = 1 + # cache batch_size + self.batch_size = batch_size + else: + batch_size = 1 return batch_size def log( From 6c9522d2d0e971490739e84ffaf2372c5e9a01d5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 06:01:17 +0100 Subject: [PATCH 30/34] Cache only if succesfully extracted --- .../trainer/connectors/logger_connector/result.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ac843c7703865..4499292b34281 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -397,19 +397,16 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int batch_size = batch_size or self.batch_size if batch_size is not None: return batch_size + batch_size = 1 # extract it fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) if meta.on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: try: batch_size = 
extract_batch_size(self.batch) + self.batch_size = batch_size except RecursionError: - batch_size = 1 - - # cache batch_size - self.batch_size = batch_size - else: - batch_size = 1 + pass return batch_size def log( From bb6f3c4785edcfa150237f7d98692df5c762f01c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 19 Nov 2021 15:21:41 +0530 Subject: [PATCH 31/34] minor updates --- .../connectors/logger_connector/result.py | 9 ++++++--- pytorch_lightning/utilities/data.py | 5 ++++- tests/loops/test_loop_state_dict.py | 16 ++++++++-------- tests/utilities/test_data.py | 7 +++++-- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4499292b34281..55ea62e2a904f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -394,19 +394,22 @@ def append_fn(v: ResultMetric) -> None: def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int: # check if we have extracted the batch size already - batch_size = batch_size or self.batch_size + if batch_size is None: + batch_size = self.batch_size + if batch_size is not None: return batch_size - batch_size = 1 - # extract it + batch_size = 1 fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) if meta.on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: try: + # extract it batch_size = extract_batch_size(self.batch) self.batch_size = batch_size except RecursionError: pass + return batch_size def log( diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index a75afa775848b..e6cfdcd953e61 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -29,7 +29,10 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]: if isinstance(batch, torch.Tensor): - yield batch.size(0) + if batch.ndim == 0: + yield 1 + else: + yield batch.size(0) elif isinstance(batch, str): yield len(batch) elif isinstance(batch, (Iterable, Mapping)): diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 4d1bd57f66ad6..72eeb197e9e57 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -79,15 +79,15 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "epoch_loop.val_loop._results": { - "_current_batch": None, - "_current_batch_size": None, + "batch": None, + "batch_size": None, "training": False, "device": None, "items": {}, }, "epoch_loop._results": { - "_current_batch": None, - "_current_batch_size": None, + "batch": None, + "batch_size": None, "training": True, "device": None, "items": {}, @@ -107,8 +107,8 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { - "_current_batch": None, - "_current_batch_size": None, + "batch": None, + "batch_size": None, "training": False, "device": None, "items": {}, @@ -124,8 +124,8 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { - "_current_batch": None, - "_current_batch_size": None, + "batch": None, + "batch_size": None, "training": False, "device": None, "items": {}, diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index acbe645515f55..f4c61cda64f5d 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -12,6 +12,7 @@ warning_cache, ) from 
pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset @@ -19,9 +20,8 @@ def test_extract_batch_size(): """Tests the behavior of extracting the batch size.""" def _check_warning_not_raised(data, expected): - with pytest.warns(None) as record: + with no_warning_call(match="Trying to infer the `batch_size`"): assert extract_batch_size(data) == expected - assert len(record) == 0 def _check_warning_raised(data, expected): with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."): @@ -43,6 +43,9 @@ def _check_warning_raised(data, expected): batch = {"test": [{"test": [torch.zeros(11, 10)]}]} _check_warning_not_raised(batch, 11) + batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])} + _check_warning_raised(batch, 1) + batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]} _check_warning_raised(batch, 11) From 250defaa11f10dc980158bc12fdc690a60f6d6c9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 15:15:41 +0100 Subject: [PATCH 32/34] Simplify check --- .../trainer/connectors/logger_connector/result.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 55ea62e2a904f..04e759a6217a5 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -21,7 +21,6 @@ from typing_extensions import TypedDict from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device from pytorch_lightning.utilities.data import extract_batch_size @@ -401,8 +400,7 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int return batch_size batch_size = 1 - fx_validate = _FxValidator.functions.get(meta.fx.split(".")[0]) - if meta.on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction: + if self.batch is not None and meta.on_epoch and meta.is_mean_reduction: try: # extract it batch_size = extract_batch_size(self.batch) From 1ee9b1731bf7b1df982f0a1377751faf9eddc0a5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 19 Nov 2021 15:15:57 +0100 Subject: [PATCH 33/34] Remove silly comment --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 04e759a6217a5..ac0a460239ec0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -402,7 +402,6 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int batch_size = 1 if self.batch is not None and meta.on_epoch and meta.is_mean_reduction: try: - # extract it batch_size = extract_batch_size(self.batch) self.batch_size = batch_size except RecursionError: From 560da623f7788c4da1618ce945250ef0dbac210e Mon Sep 17 00:00:00 2001 From: Carlos 
Mocholi Date: Fri, 19 Nov 2021 15:24:56 +0100 Subject: [PATCH 34/34] mypy fix --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ac0a460239ec0..06da64bf3ed8e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -487,7 +487,7 @@ def fn(v: _IN_METRIC) -> ResultMetric: self[key] = value def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None: - def fn(result_metric: ResultMetric, v: ResultMetric) -> None: + def fn(result_metric: ResultMetric, v: torch.Tensor) -> None: # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` result_metric.forward(v.to(self.device), batch_size) result_metric.has_reset = False
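Taken together, patches 14 through 34 replace the old per-call batch-size lookup with a per-batch cache on `ResultCollection`: the logger connector attaches the current batch and resets the cached size at batch start, drops both at batch end, and `_extract_batch_size` infers a size only for mean-reduced `on_epoch` metrics, caching the result so inference runs at most once per batch. The sketch below mirrors that behaviour outside of Lightning; `ToyResults`, `_Meta`, and `_infer_batch_size` are illustrative stand-ins chosen for this sketch (the real logic lives on `ResultCollection` and calls `pytorch_lightning.utilities.data.extract_batch_size`), not the actual Trainer internals.

from dataclasses import dataclass
from typing import Any, Optional

import torch


def _infer_batch_size(batch: Any) -> int:
    # illustrative stand-in for pytorch_lightning.utilities.data.extract_batch_size
    if isinstance(batch, torch.Tensor):
        return 1 if batch.ndim == 0 else batch.size(0)
    if isinstance(batch, dict) and batch:
        return _infer_batch_size(next(iter(batch.values())))
    if isinstance(batch, (list, tuple)) and batch:
        return _infer_batch_size(batch[0])
    return 1


@dataclass
class _Meta:
    # the two flags the final `_extract_batch_size` consults
    on_epoch: bool = True
    is_mean_reduction: bool = True


class ToyResults:
    # simplified, Trainer-free analogue of ResultCollection's batch-size cache
    def __init__(self) -> None:
        self.batch: Optional[Any] = None       # attached at batch start, dropped at batch end
        self.batch_size: Optional[int] = None  # reset to None for every new batch

    def _extract_batch_size(self, batch_size: Optional[int], meta: _Meta) -> int:
        # a user-provided or previously cached size short-circuits any inference
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size is not None:
            return batch_size

        batch_size = 1
        # infer only when it can matter: mean-reduced metrics logged on_epoch
        if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
            try:
                batch_size = _infer_batch_size(self.batch)
                self.batch_size = batch_size   # cache only on successful extraction
            except RecursionError:
                pass
        return batch_size


results = ToyResults()
results.batch = {"batch1": torch.randn(8, 10), "batch2": torch.randn(4, 10)}  # ambiguous batch
assert results._extract_batch_size(32, _Meta()) == 32   # explicit size is never overridden
assert results._extract_batch_size(None, _Meta()) == 8  # inferred once from the attached batch
assert results._extract_batch_size(None, _Meta()) == 8  # later logs reuse the cached value

The explicit-size path never touches the batch at all, which is what the new `test_no_batch_size_extraction_with_specifying_explictly` test checks by wrapping `trainer.fit` in `no_warning_call(match="Trying to infer the `batch_size`")`.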