diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1eed9848cdbb..670be90f3742c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -196,8 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
 
-- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
+- Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
+
+- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
 
 -
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 1c27b75854d96..4c29d30ee1a5a 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -338,12 +338,13 @@ class ResultMetricCollection(dict):
     with the same metadata.
     """
 
-    def __init__(self, *args: Any) -> None:
-        super().__init__(*args)
-
     @property
     def meta(self) -> _Metadata:
-        return list(self.values())[0].meta
+        return next(iter(self.values())).meta
+
+    @property
+    def has_tensor(self) -> bool:
+        return any(v.is_tensor for v in self.values())
 
     def __getstate__(self, drop_value: bool = False) -> dict:
         def getstate(item: ResultMetric) -> dict:
@@ -403,7 +404,7 @@ def append_fn(v: ResultMetric) -> None:
         apply_to_collection(list(self.values()), ResultMetric, append_fn)
         return o
 
-    def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
+    def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
         # check if we have extracted the batch size already
         if batch_size is None:
             batch_size = self.batch_size
@@ -412,7 +413,8 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int
             return batch_size
 
         batch_size = 1
-        if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
+        is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
+        if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
             batch_size = extract_batch_size(self.batch)
             self.batch_size = batch_size
 
@@ -477,7 +479,7 @@ def log(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
 
-        batch_size = self._extract_batch_size(batch_size, meta)
+        batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
     def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index a3e3c0aec9d34..b3b6667e971af 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -17,7 +17,7 @@
 import pytest
 import torch
 from torch.utils.data import DataLoader
-from torchmetrics import Accuracy, AveragePrecision
+from torchmetrics import Accuracy, AveragePrecision, MeanAbsoluteError, MeanSquaredError
 
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks.base import Callback
@@ -640,3 +640,51 @@ def training_step(self, batch, batch_idx):
 
     # should not get overridden if logged manually
     assert trainer.logged_metrics == {"epoch": -1}
+
+
+def test_result_collection_batch_size_extraction():
+    fx_name = "training_step"
+    log_val = torch.tensor(7.0)
+
+    results = ResultCollection(training=True, device="cpu")
+    results.batch = torch.randn(1, 4)
+    train_mse = MeanSquaredError()
+    train_mse(torch.randn(4, 5), torch.randn(4, 5))
+    results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True)
+    assert results.batch_size == 1
+    assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
+    assert results["training_step.train_logs"]["log_val"].value == log_val
+
+    results = ResultCollection(training=True, device="cpu")
+    results.batch = torch.randn(1, 4)
+    results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
+    assert results.batch_size == 1
+    assert results["training_step.train_log"].value == log_val
+    assert results["training_step.train_log"].cumulated_batch_size == 1
+
+
+def test_result_collection_no_batch_size_extraction():
+    results = ResultCollection(training=True, device="cpu")
+    results.batch = torch.randn(1, 4)
+    fx_name = "training_step"
+    batch_size = 10
+    log_val = torch.tensor(7.0)
+
+    train_mae = MeanAbsoluteError()
+    train_mae(torch.randn(4, 5), torch.randn(4, 5))
+    train_mse = MeanSquaredError()
+    train_mse(torch.randn(4, 5), torch.randn(4, 5))
+    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False)
+    results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size)
+    results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum")
+    results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False)
+    results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False)
+
+    assert results.batch_size is None
+    assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError)
+    assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
+    assert results["training_step.step_log_val"].value == log_val
+    assert results["training_step.step_log_val"].cumulated_batch_size == 0
+    assert results["training_step.epoch_log_val"].value == log_val * batch_size
+    assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
+    assert results["training_step.epoch_sum_log_val"].value == log_val
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 139714acc97bc..2ad2585f0fe02 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -27,7 +27,6 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.deprecated_api import no_warning_call
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
 from tests.helpers.runif import RunIf
 
@@ -746,36 +745,3 @@ def validation_epoch_end(self, *_) -> None:
     train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
-
-
-def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
-    batch_size = BoringModel().train_dataloader().batch_size + 1
-    fast_dev_run = 2
-    log_val = 7
-
-    class CustomBoringModel(BoringModel):
-        def on_before_batch_transfer(self, batch, *args, **kwargs):
-            # This is an ambiguous batch which have multiple potential batch sizes
-            if self.trainer.training:
-                batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
-            return batch
-
-        def training_step(self, batch, batch_idx):
-            self.log("step_log_val", log_val, on_epoch=False)
-            self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
-            self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
-            return super().training_step(batch["batch2"], batch_idx)
-
-        def on_train_epoch_end(self, *args, **kwargs):
-            results = self.trainer._results
-            assert results["training_step.step_log_val"].value == log_val
-            assert results["training_step.step_log_val"].cumulated_batch_size == 0
-            assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
-            assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
-            assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run
-
-    model = CustomBoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)
-
-    with no_warning_call(match="Trying to infer the `batch_size`"):
-        trainer.fit(model)
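For context, a minimal sketch of the behaviour this diff introduces, mirroring the new tests in test_logger_connector.py: logging a torchmetrics instance no longer triggers batch-size extraction (the metric accumulates its own state), while logging a plain tensor with on_epoch=True and the default mean reduction still infers the batch size from the current batch. The metric names are illustrative, and the sketch relies on the internal ResultCollection API shown above, which is not part of the public interface and may change:

# Sketch only: uses the internal ResultCollection API from the diff above and assumes
# torchmetrics is installed; the logged names ("train_mse", "train_loss") are illustrative.
import torch
from torchmetrics import MeanSquaredError

from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection

results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)  # the batch Lightning could inspect for a batch size

# Logging a torchmetrics instance: no batch_size is extracted because the metric
# already accumulates its state internally.
mse = MeanSquaredError()
mse(torch.randn(4, 5), torch.randn(4, 5))
results.log("training_step", "train_mse", mse, on_step=True, on_epoch=False)
assert results.batch_size is None

# Logging a plain tensor with on_epoch=True and mean reduction: the batch size is
# still inferred from results.batch (here 1, the first dimension of the batch).
results.log("training_step", "train_loss", torch.tensor(7.0), on_step=False, on_epoch=True)
assert results.batch_size == 1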