Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable batch_size extraction for torchmetric instances #10815

Merged
merged 9 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))


- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
- Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))


- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))


-
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,13 @@ class ResultMetricCollection(dict):
with the same metadata.
"""

def __init__(self, *args: Any) -> None:
super().__init__(*args)

@property
def meta(self) -> _Metadata:
return list(self.values())[0].meta
return next(iter(self.values())).meta

@property
def has_tensor(self) -> bool:
return any(v.is_tensor for v in self.values())

def __getstate__(self, drop_value: bool = False) -> dict:
def getstate(item: ResultMetric) -> dict:
Expand Down Expand Up @@ -403,7 +404,7 @@ def append_fn(v: ResultMetric) -> None:
apply_to_collection(list(self.values()), ResultMetric, append_fn)
return o

def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = self.batch_size
Expand All @@ -412,7 +413,8 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int
return batch_size

batch_size = 1
if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size

Expand Down Expand Up @@ -477,7 +479,7 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

batch_size = self._extract_batch_size(batch_size, meta)
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
Expand Down
50 changes: 49 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, AveragePrecision
from torchmetrics import Accuracy, AveragePrecision, MeanAbsoluteError, MeanSquaredError

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.base import Callback
Expand Down Expand Up @@ -640,3 +640,51 @@ def training_step(self, batch, batch_idx):

# should not get overridden if logged manually
assert trainer.logged_metrics == {"epoch": -1}


def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)

results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True)
assert results.batch_size == 1
assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
assert results["training_step.train_logs"]["log_val"].value == log_val

results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
assert results["training_step.train_log"].value == log_val
assert results["training_step.train_log"].cumulated_batch_size == 1


def test_result_collection_no_batch_size_extraction():
results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10
log_val = torch.tensor(7.0)

train_mae = MeanAbsoluteError()
train_mae(torch.randn(4, 5), torch.randn(4, 5))
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False)
results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size)
results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum")
results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False)
results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False)

assert results.batch_size is None
assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError)
assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
assert results["training_step.epoch_sum_log_val"].value == log_val
34 changes: 0 additions & 34 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -746,36 +745,3 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 1
fast_dev_run = 2
log_val = 7

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def on_train_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with no_warning_call(match="Trying to infer the `batch_size`"):
trainer.fit(model)