Skip to content

Commit

Permalink
Fix batch size extraction when set by the user in `LightningModule.lo…
Browse files Browse the repository at this point in the history
…g` (#10408)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
rohitgr7 and carmocca committed Nov 20, 2021
1 parent db1b960 commit affd617
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 62 deletions.
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

# cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
# different if tbptt is enabled
batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
Expand Down Expand Up @@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.trainer._results.batch_size = batch_size

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:

def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
self._split_idx = split_idx
self.on_new_batch(split_batch)

def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
Expand Down Expand Up @@ -253,28 +252,23 @@ def _log_gpus_metrics(self) -> None:
Utilities and properties
"""

def on_new_batch(self, batch: Any) -> int:
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
assert self.trainer._results is not None
return self.trainer._results.extract_batch_size(batch)
return 1

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int, batch: Any) -> int:
def on_batch_start(self, batch_idx: int, batch: Any) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False
return self.on_new_batch(batch)

assert self.trainer._results is not None
# attach reference to the new batch and remove the cached batch_size
self.trainer._results.batch = batch
self.trainer._results.batch_size = None

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
assert self.trainer._results is not None
self.trainer._results.batch_size = 1

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
Expand All @@ -291,6 +285,11 @@ def on_batch_end(self) -> None:
self._callback_metrics.update(metrics["callback"])
self._logged_metrics.update(metrics["log"])

assert self.trainer._results is not None
# drop the reference to current batch and batch_size
self.trainer._results.batch = None
self.trainer._results.batch_size = None

def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
if self._split_idx is None:
Expand Down
51 changes: 25 additions & 26 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
if self.meta.is_mean_reduction:
self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum)

def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def update(self, value: _IN_METRIC, batch_size: int) -> None:
if self.is_tensor:
value = value.float()
if self.meta.on_step:
Expand Down Expand Up @@ -250,7 +250,7 @@ def reset(self) -> None:
self.value.reset()
self.has_reset = True

def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def forward(self, value: _IN_METRIC, batch_size: int) -> None:
if self.meta.enable_graph:
with torch.no_grad():
self.update(value, batch_size)
Expand Down Expand Up @@ -376,8 +376,9 @@ class ResultCollection(dict):
def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
super().__init__()
self.training = training
self._batch_size = torch.tensor(1, device=device)
self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None

@property
def result_metrics(self) -> List[ResultMetric]:
Expand All @@ -390,14 +391,23 @@ def append_fn(v: ResultMetric) -> None:
apply_to_collection(list(self.values()), ResultMetric, append_fn)
return o

@property
def batch_size(self) -> torch.Tensor:
# performance: cache the `batch_size` tensor instead of re-creating it
return self._batch_size
def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = self.batch_size

if batch_size is not None:
return batch_size

@batch_size.setter
def batch_size(self, value: int) -> None:
self._batch_size = torch.tensor(value, device=self.device)
batch_size = 1
if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
try:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size
except RecursionError:
pass

return batch_size

def log(
self,
Expand Down Expand Up @@ -458,10 +468,8 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

if batch_size is not None:
self.batch_size = batch_size

self.update_metrics(key, value)
batch_size = self._extract_batch_size(batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value.
Expand All @@ -478,10 +486,10 @@ def fn(v: _IN_METRIC) -> ResultMetric:
value = ResultMetricCollection(value)
self[key] = value

def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), self.batch_size)
result_metric.forward(v.to(self.device), batch_size)
result_metric.has_reset = False

apply_to_collections(self[key], value, ResultMetric, fn)
Expand Down Expand Up @@ -575,19 +583,10 @@ def fn(item: ResultMetric) -> None:

apply_to_collection(self, ResultMetric, fn)

def extract_batch_size(self, batch: Any) -> int:
try:
batch_size = extract_batch_size(batch)
except RecursionError:
batch_size = 1
self.batch_size = batch_size # the setter converts it to `Tensor`
return batch_size

def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

self._batch_size = self._batch_size.to(*args, **kwargs)
if "device" in kwargs:
self.device = kwargs["device"]
return self
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@

def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
if isinstance(batch, torch.Tensor):
yield batch.size(0)
if batch.ndim == 0:
yield 1
else:
yield batch.size(0)
elif isinstance(batch, str):
yield len(batch)
elif isinstance(batch, (Iterable, Mapping)):
Expand Down
28 changes: 21 additions & 7 deletions tests/deprecated_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Test deprecated functionality which will be removed in vX.Y.Z."""
import sys
from contextlib import contextmanager
from typing import Optional
from typing import Optional, Type

import pytest

Expand All @@ -26,14 +26,28 @@ def _soft_unimport_module(str_module):


@contextmanager
def no_deprecated_call(match: Optional[str] = None):
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(DeprecationWarning)
if match is not None and match not in str(w.message):
return
w = record.pop(expected_warning)
except AssertionError:
# no DeprecationWarning raised
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and match in w.message.args[0]:
break
else:
return
raise AssertionError(f"`DeprecationWarning` was raised: {w}")

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")


@contextmanager
def no_deprecated_call(match: Optional[str] = None):
with no_warning_call(expected_warning=DeprecationWarning, match=match):
yield
13 changes: 8 additions & 5 deletions tests/loops/test_loop_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.loops import FitLoop
from pytorch_lightning.trainer.trainer import Trainer
Expand Down Expand Up @@ -80,14 +79,16 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"epoch_loop.val_loop._results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
"epoch_loop._results": {
"batch": None,
"batch_size": None,
"training": True,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand All @@ -106,8 +107,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand All @@ -122,8 +124,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand Down
38 changes: 34 additions & 4 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -715,19 +716,15 @@ def on_validation_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_train_batch_start(self, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_start", 1.0, reduce_fx="sum")

def on_train_batch_end(self, outputs, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_end", 1.0, reduce_fx="sum")

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

def training_epoch_end(self, *_) -> None:
Expand All @@ -749,3 +746,36 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 1
fast_dev_run = 2
log_val = 7

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def on_train_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with no_warning_call(match="Trying to infer the `batch_size`"):
trainer.fit(model)
7 changes: 5 additions & 2 deletions tests/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


def test_extract_batch_size():
"""Tests the behavior of extracting the batch size."""

def _check_warning_not_raised(data, expected):
with pytest.warns(None) as record:
with no_warning_call(match="Trying to infer the `batch_size`"):
assert extract_batch_size(data) == expected
assert len(record) == 0

def _check_warning_raised(data, expected):
with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
Expand All @@ -43,6 +43,9 @@ def _check_warning_raised(data, expected):
batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
_check_warning_not_raised(batch, 11)

batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
_check_warning_raised(batch, 1)

batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
_check_warning_raised(batch, 11)

Expand Down

0 comments on commit affd617

Please sign in to comment.