Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable batch_size extraction for torchmetric instances #10815

Merged
merged 9 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))


- Disable batch_size extraction for torchmetric instances ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved


-


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def append_fn(v: ResultMetric) -> None:
apply_to_collection(list(self.values()), ResultMetric, append_fn)
return o

def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = self.batch_size
Expand All @@ -412,7 +412,12 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int
return batch_size

batch_size = 1
if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
is_tensor = (
value.is_tensor
if isinstance(value, ResultMetric)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
else any(isinstance(val, ResultMetric) for val in value.values())
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size

Expand Down Expand Up @@ -477,7 +482,7 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

batch_size = self._extract_batch_size(batch_size, meta)
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
Expand Down
28 changes: 15 additions & 13 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import collections
import itertools
from re import escape
from unittest import mock

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics import Accuracy, MeanAbsoluteError

from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -748,26 +748,29 @@ def validation_epoch_end(self, *_) -> None:
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 1
@mock.patch("pytorch_lightning.utilities.data.extract_batch_size")
def test_no_batch_size_extraction_with_specifying_explictly(mock_method, tmpdir):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
batch_size = 10
fast_dev_run = 2
log_val = 7

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
return batch
def __init__(self):
super().__init__()
self.train_mae = MeanAbsoluteError()

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

self.train_mae(batch, torch.ones_like(batch))
self.log("train_mae", self.train_mae)
return super().training_step(batch, batch_idx)

def on_train_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
Expand All @@ -776,6 +779,5 @@ def on_train_epoch_end(self, *args, **kwargs):

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with no_warning_call(match="Trying to infer the `batch_size`"):
trainer.fit(model)
trainer.fit(model)
mock_method.assert_not_called()