From 328ab770acc84e57d46b6bd9267a140b1eaa2e0e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 29 Nov 2021 20:13:28 +0530 Subject: [PATCH 1/7] disable batch_size extraction for torchmetric instances --- .../connectors/logger_connector/result.py | 11 ++++++-- .../logging_/test_train_loop_logging.py | 28 ++++++++++--------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 1c27b75854d96..079d2d05e504e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -403,7 +403,7 @@ def append_fn(v: ResultMetric) -> None: apply_to_collection(list(self.values()), ResultMetric, append_fn) return o - def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int: + def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: batch_size = self.batch_size @@ -412,7 +412,12 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int return batch_size batch_size = 1 - if self.batch is not None and meta.on_epoch and meta.is_mean_reduction: + is_tensor = ( + value.is_tensor + if isinstance(value, ResultMetric) + else any(isinstance(val, ResultMetric) for val in value.values()) + ) + if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction: batch_size = extract_batch_size(self.batch) self.batch_size = batch_size @@ -477,7 +482,7 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed" ) - batch_size = self._extract_batch_size(batch_size, meta) + batch_size = self._extract_batch_size(self[key], batch_size, meta) self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 6bfbaa9a7bcb1..d899185b639e2 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -16,18 +16,18 @@ import collections import itertools from re import escape +from unittest import mock import numpy as np import pytest import torch from torch.utils.data import DataLoader -from torchmetrics import Accuracy +from torchmetrics import Accuracy, MeanAbsoluteError from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset from tests.helpers.runif import RunIf @@ -748,26 +748,29 @@ def validation_epoch_end(self, *_) -> None: trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) -def test_no_batch_size_extraction_with_specifying_explictly(tmpdir): - batch_size = BoringModel().train_dataloader().batch_size + 1 +@mock.patch("pytorch_lightning.utilities.data.extract_batch_size") +def test_no_batch_size_extraction_with_specifying_explictly(mock_method, tmpdir): + batch_size = 10 fast_dev_run = 2 log_val = 7 class CustomBoringModel(BoringModel): - def on_before_batch_transfer(self, batch, *args, **kwargs): - # This is an ambiguous batch which have multiple potential batch sizes - if self.trainer.training: - batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch} - return batch + def __init__(self): + super().__init__() + self.train_mae = MeanAbsoluteError() def training_step(self, batch, batch_idx): self.log("step_log_val", log_val, on_epoch=False) self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True) self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") - return super().training_step(batch["batch2"], batch_idx) + + self.train_mae(batch, torch.ones_like(batch)) + self.log("train_mae", self.train_mae) + return super().training_step(batch, batch_idx) def on_train_epoch_end(self, *args, **kwargs): results = self.trainer._results + assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError) assert results["training_step.step_log_val"].value == log_val assert results["training_step.step_log_val"].cumulated_batch_size == 0 assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run @@ -776,6 +779,5 @@ def on_train_epoch_end(self, *args, **kwargs): model = CustomBoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) - - with no_warning_call(match="Trying to infer the `batch_size`"): - trainer.fit(model) + trainer.fit(model) + mock_method.assert_not_called() From 8d31e91ca47dbcc79cc44f9b83a6b5d7a4dd12a6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 29 Nov 2021 20:22:42 +0530 Subject: [PATCH 2/7] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb5d26a1071a6..d00e36a6b06d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) +- Disable batch_size extraction for torchmetric instances ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815)) + + - From 91349aec33de1f30b965b3ac5e90f19b34abc53b Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 29 Nov 2021 20:25:57 +0530 Subject: [PATCH 3/7] Apply suggestions from code review --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d00e36a6b06d5..7d371475c5ee9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,7 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) -- Disable batch_size extraction for torchmetric instances ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815)) +- Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815)) - From 276ed0b235628f44a1cf54b0202e20c2a098aec5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Nov 2021 16:33:13 +0100 Subject: [PATCH 4/7] Add minor improvements --- .../trainer/connectors/logger_connector/result.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 079d2d05e504e..5de5b36c54883 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -338,12 +338,9 @@ class ResultMetricCollection(dict): with the same metadata. """ - def __init__(self, *args: Any) -> None: - super().__init__(*args) - @property def meta(self) -> _Metadata: - return list(self.values())[0].meta + return next(iter(self.values())).meta def __getstate__(self, drop_value: bool = False) -> dict: def getstate(item: ResultMetric) -> dict: From 1e5c27dec71f54fdecf4eb1d4d52f9880e586255 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Nov 2021 16:36:19 +0100 Subject: [PATCH 5/7] has_tensor property --- .../trainer/connectors/logger_connector/result.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 5de5b36c54883..4c29d30ee1a5a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -342,6 +342,10 @@ class ResultMetricCollection(dict): def meta(self) -> _Metadata: return next(iter(self.values())).meta + @property + def has_tensor(self) -> bool: + return any(v.is_tensor for v in self.values()) + def __getstate__(self, drop_value: bool = False) -> dict: def getstate(item: ResultMetric) -> dict: return item.__getstate__(drop_value=drop_value) @@ -409,11 +413,7 @@ def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[in return batch_size batch_size = 1 - is_tensor = ( - value.is_tensor - if isinstance(value, ResultMetric) - else any(isinstance(val, ResultMetric) for val in value.values()) - ) + is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction: batch_size = extract_batch_size(self.batch) self.batch_size = batch_size From ee34f095e82c21dd7fcf5bc51dbfb0414b86d466 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 30 Nov 2021 14:14:02 +0530 Subject: [PATCH 6/7] add more tests --- .../logging_/test_train_loop_logging.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index d899185b639e2..1469ab838edda 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -22,7 +22,7 @@ import pytest import torch from torch.utils.data import DataLoader -from torchmetrics import Accuracy, MeanAbsoluteError +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar @@ -748,8 +748,34 @@ def validation_epoch_end(self, *_) -> None: trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) +@mock.patch("pytorch_lightning.utilities.data.extract_batch_size", return_value=1) +def test_batch_size_extraction_with_mixed_metrics(mock_method, tmpdir): + fast_dev_run = 2 + log_val = 7 + + class CustomBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.train_mse = MeanSquaredError() + + def training_step(self, batch, batch_idx): + self.train_mse(batch, torch.ones_like(batch)) + self.log("train_logs", {"mse": self.train_mse, "log_val": log_val}, on_step=False, on_epoch=True) + return super().training_step(batch, batch_idx) + + def on_train_epoch_end(self, *args, **kwargs): + results = self.trainer._results["training_step.train_logs"] + assert isinstance(results["mse"].value, MeanSquaredError) + assert results["log_val"].value == log_val * fast_dev_run + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) + trainer.fit(model) + mock_method.assert_called() + + @mock.patch("pytorch_lightning.utilities.data.extract_batch_size") -def test_no_batch_size_extraction_with_specifying_explictly(mock_method, tmpdir): +def test_no_batch_size_extraction_when_not_required(mock_method, tmpdir): batch_size = 10 fast_dev_run = 2 log_val = 7 @@ -758,6 +784,7 @@ class CustomBoringModel(BoringModel): def __init__(self): super().__init__() self.train_mae = MeanAbsoluteError() + self.train_mse = MeanSquaredError() def training_step(self, batch, batch_idx): self.log("step_log_val", log_val, on_epoch=False) @@ -765,11 +792,14 @@ def training_step(self, batch, batch_idx): self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") self.train_mae(batch, torch.ones_like(batch)) + self.train_mse(batch, torch.ones_like(batch)) self.log("train_mae", self.train_mae) + self.log("train_mse", {"mse": self.train_mse}) return super().training_step(batch, batch_idx) def on_train_epoch_end(self, *args, **kwargs): results = self.trainer._results + assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError) assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError) assert results["training_step.step_log_val"].value == log_val assert results["training_step.step_log_val"].cumulated_batch_size == 0 From 167f4d574f6964c7ab341fc3a640bec712bdd170 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 30 Nov 2021 15:43:43 +0530 Subject: [PATCH 7/7] convert to unittesting --- .../trainer/logging_/test_logger_connector.py | 50 +++++++++++++- .../logging_/test_train_loop_logging.py | 68 +------------------ 2 files changed, 50 insertions(+), 68 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 6e5c6d2ddbe1b..41f99bf50d5a9 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -17,7 +17,7 @@ import pytest import torch from torch.utils.data import DataLoader -from torchmetrics import Accuracy, AveragePrecision +from torchmetrics import Accuracy, AveragePrecision, MeanAbsoluteError, MeanSquaredError from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback @@ -640,3 +640,51 @@ def training_step(self, batch, batch_idx): # should not get overridden if logged manually assert trainer.logged_metrics == {"epoch": -1} + + +def test_result_collection_batch_size_extraction(): + fx_name = "training_step" + log_val = torch.tensor(7.0) + + results = ResultCollection(training=True, device="cpu") + results.batch = torch.randn(1, 4) + train_mse = MeanSquaredError() + train_mse(torch.randn(4, 5), torch.randn(4, 5)) + results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True) + assert results.batch_size == 1 + assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError) + assert results["training_step.train_logs"]["log_val"].value == log_val + + results = ResultCollection(training=True, device="cpu") + results.batch = torch.randn(1, 4) + results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True) + assert results.batch_size == 1 + assert results["training_step.train_log"].value == log_val + assert results["training_step.train_log"].cumulated_batch_size == 1 + + +def test_result_collection_no_batch_size_extraction(): + results = ResultCollection(training=True, device="cpu") + results.batch = torch.randn(1, 4) + fx_name = "training_step" + batch_size = 10 + log_val = torch.tensor(7.0) + + train_mae = MeanAbsoluteError() + train_mae(torch.randn(4, 5), torch.randn(4, 5)) + train_mse = MeanSquaredError() + train_mse(torch.randn(4, 5), torch.randn(4, 5)) + results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False) + results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size) + results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum") + results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False) + results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False) + + assert results.batch_size is None + assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError) + assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError) + assert results["training_step.step_log_val"].value == log_val + assert results["training_step.step_log_val"].cumulated_batch_size == 0 + assert results["training_step.epoch_log_val"].value == log_val * batch_size + assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size + assert results["training_step.epoch_sum_log_val"].value == log_val diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 1469ab838edda..e89d6b3f614eb 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -16,13 +16,12 @@ import collections import itertools from re import escape -from unittest import mock import numpy as np import pytest import torch from torch.utils.data import DataLoader -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError +from torchmetrics import Accuracy from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar @@ -746,68 +745,3 @@ def validation_epoch_end(self, *_) -> None: train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - - -@mock.patch("pytorch_lightning.utilities.data.extract_batch_size", return_value=1) -def test_batch_size_extraction_with_mixed_metrics(mock_method, tmpdir): - fast_dev_run = 2 - log_val = 7 - - class CustomBoringModel(BoringModel): - def __init__(self): - super().__init__() - self.train_mse = MeanSquaredError() - - def training_step(self, batch, batch_idx): - self.train_mse(batch, torch.ones_like(batch)) - self.log("train_logs", {"mse": self.train_mse, "log_val": log_val}, on_step=False, on_epoch=True) - return super().training_step(batch, batch_idx) - - def on_train_epoch_end(self, *args, **kwargs): - results = self.trainer._results["training_step.train_logs"] - assert isinstance(results["mse"].value, MeanSquaredError) - assert results["log_val"].value == log_val * fast_dev_run - - model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) - trainer.fit(model) - mock_method.assert_called() - - -@mock.patch("pytorch_lightning.utilities.data.extract_batch_size") -def test_no_batch_size_extraction_when_not_required(mock_method, tmpdir): - batch_size = 10 - fast_dev_run = 2 - log_val = 7 - - class CustomBoringModel(BoringModel): - def __init__(self): - super().__init__() - self.train_mae = MeanAbsoluteError() - self.train_mse = MeanSquaredError() - - def training_step(self, batch, batch_idx): - self.log("step_log_val", log_val, on_epoch=False) - self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True) - self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") - - self.train_mae(batch, torch.ones_like(batch)) - self.train_mse(batch, torch.ones_like(batch)) - self.log("train_mae", self.train_mae) - self.log("train_mse", {"mse": self.train_mse}) - return super().training_step(batch, batch_idx) - - def on_train_epoch_end(self, *args, **kwargs): - results = self.trainer._results - assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError) - assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError) - assert results["training_step.step_log_val"].value == log_val - assert results["training_step.step_log_val"].cumulated_batch_size == 0 - assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run - assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run - assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run - - model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) - trainer.fit(model) - mock_method.assert_not_called()