Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an issue that caused Lightning to extract the batch size even though it was set by the user in LightningModule.log #10408

Merged
merged 41 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5e10b93
extract batch size only when it's required
rohitgr7 Nov 8, 2021
8941e32
revamp log
rohitgr7 Nov 8, 2021
51a7d36
fix
rohitgr7 Nov 8, 2021
c2a3626
fix
rohitgr7 Nov 8, 2021
be5882c
update tests
rohitgr7 Nov 8, 2021
bba8055
chlog
rohitgr7 Nov 8, 2021
308f9c0
fix
rohitgr7 Nov 9, 2021
04314b8
add test
rohitgr7 Nov 9, 2021
b1de03e
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 9, 2021
46a8861
mypy
rohitgr7 Nov 9, 2021
9b28e44
mypy
rohitgr7 Nov 9, 2021
0772967
deref current batch
rohitgr7 Nov 9, 2021
c3567ff
mypy
rohitgr7 Nov 9, 2021
2ec9885
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 9, 2021
1484284
rev
rohitgr7 Nov 9, 2021
4363e69
cache batch size
rohitgr7 Nov 10, 2021
94c5749
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 10, 2021
cb9faae
update test
rohitgr7 Nov 10, 2021
61b9483
mypy
rohitgr7 Nov 10, 2021
35c37ac
mypy
rohitgr7 Nov 10, 2021
46c22be
move to resultcollection
rohitgr7 Nov 10, 2021
9b25ecf
mypy
rohitgr7 Nov 10, 2021
b7a2296
update logic
rohitgr7 Nov 10, 2021
55a189e
Apply suggestions from code review
rohitgr7 Nov 18, 2021
e89b6d2
Merge remote-tracking branch 'origin/master' into fix/batch_size
rohitgr7 Nov 18, 2021
c9a8543
chlog
rohitgr7 Nov 18, 2021
a053dd1
update on comments
rohitgr7 Nov 18, 2021
a61dad1
Use our utilities
carmocca Nov 19, 2021
b9d2a56
Whitespace
carmocca Nov 19, 2021
f87d215
Remove unnecessary properties
carmocca Nov 19, 2021
0074b81
Remove current prefix
carmocca Nov 19, 2021
cb9de15
Simplify arguments
carmocca Nov 19, 2021
92cd293
Avoid indentations
carmocca Nov 19, 2021
6c9522d
Cache only if successfully extracted
carmocca Nov 19, 2021
2979772
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
bb6f3c4
minor updates
rohitgr7 Nov 19, 2021
250defa
Simplify check
carmocca Nov 19, 2021
1ee9b17
Remove silly comment
carmocca Nov 19, 2021
b806ab0
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
560da62
mypy fix
carmocca Nov 19, 2021
4c859d3
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))


- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408))


-


Expand Down
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

# cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
# different if tbptt is enabled
batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
Expand Down Expand Up @@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.trainer._results.batch_size = batch_size
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:

def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
self._split_idx = split_idx
self.on_new_batch(split_batch)

def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
Expand Down Expand Up @@ -257,28 +256,23 @@ def _log_gpus_metrics(self) -> None:
Utilities and properties
"""

def on_new_batch(self, batch: Any) -> int:
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
assert self.trainer._results is not None
return self.trainer._results.extract_batch_size(batch)
return 1

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int, batch: Any) -> int:
def on_batch_start(self, batch_idx: int, batch: Any) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False
return self.on_new_batch(batch)

assert self.trainer._results is not None
# attach reference to the new batch and remove the cached batch_size
self.trainer._results.batch = batch
self.trainer._results.batch_size = None

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
assert self.trainer._results is not None
self.trainer._results.batch_size = 1

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
Expand All @@ -295,6 +289,11 @@ def on_batch_end(self) -> None:
self._callback_metrics.update(metrics["callback"])
self._logged_metrics.update(metrics["log"])

assert self.trainer._results is not None
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
# drop the reference to current batch and batch_size
self.trainer._results.batch = None
self.trainer._results.batch_size = None

def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
if self._split_idx is None:
Expand Down
51 changes: 25 additions & 26 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
if self.meta.is_mean_reduction:
self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)

def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def update(self, value: _IN_METRIC, batch_size: int) -> None:
if self.is_tensor:
if not torch.is_floating_point(value):
dtype = torch.get_default_dtype()
Expand Down Expand Up @@ -259,7 +259,7 @@ def reset(self) -> None:
self.value.reset()
self.has_reset = True

def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def forward(self, value: _IN_METRIC, batch_size: int) -> None:
if self.meta.enable_graph:
with torch.no_grad():
self.update(value, batch_size)
Expand Down Expand Up @@ -385,8 +385,9 @@ class ResultCollection(dict):
def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
super().__init__()
self.training = training
self._batch_size = torch.tensor(1, device=device)
self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None

@property
def result_metrics(self) -> List[ResultMetric]:
Expand All @@ -399,14 +400,23 @@ def append_fn(v: ResultMetric) -> None:
apply_to_collection(list(self.values()), ResultMetric, append_fn)
return o

@property
def batch_size(self) -> torch.Tensor:
# performance: cache the `batch_size` tensor instead of re-creating it
return self._batch_size
def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = self.batch_size

if batch_size is not None:
return batch_size

@batch_size.setter
def batch_size(self, value: int) -> None:
self._batch_size = torch.tensor(value, device=self.device)
batch_size = 1
if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
try:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size
except RecursionError:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
pass

return batch_size

def log(
self,
Expand Down Expand Up @@ -467,10 +477,8 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

if batch_size is not None:
self.batch_size = batch_size

self.update_metrics(key, value)
batch_size = self._extract_batch_size(batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value.
Expand All @@ -487,10 +495,10 @@ def fn(v: _IN_METRIC) -> ResultMetric:
value = ResultMetricCollection(value)
self[key] = value

def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), self.batch_size)
result_metric.forward(v.to(self.device), batch_size)
result_metric.has_reset = False

apply_to_collections(self[key], value, ResultMetric, fn)
Expand Down Expand Up @@ -584,19 +592,10 @@ def fn(item: ResultMetric) -> None:

apply_to_collection(self, ResultMetric, fn)

def extract_batch_size(self, batch: Any) -> int:
try:
batch_size = extract_batch_size(batch)
except RecursionError:
batch_size = 1
self.batch_size = batch_size # the setter converts it to `Tensor`
return batch_size

def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

self._batch_size = self._batch_size.to(*args, **kwargs)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if "device" in kwargs:
self.device = kwargs["device"]
return self
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@

def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
if isinstance(batch, torch.Tensor):
yield batch.size(0)
if batch.ndim == 0:
yield 1
else:
yield batch.size(0)
elif isinstance(batch, str):
yield len(batch)
elif isinstance(batch, (Iterable, Mapping)):
Expand Down
28 changes: 21 additions & 7 deletions tests/deprecated_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Test deprecated functionality which will be removed in vX.Y.Z."""
import sys
from contextlib import contextmanager
from typing import Optional
from typing import Optional, Type

import pytest

Expand All @@ -26,14 +26,28 @@ def _soft_unimport_module(str_module):


@contextmanager
def no_deprecated_call(match: Optional[str] = None):
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(DeprecationWarning)
if match is not None and match not in str(w.message):
return
w = record.pop(expected_warning)
except AssertionError:
# no DeprecationWarning raised
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and match in w.message.args[0]:
break
else:
return
raise AssertionError(f"`DeprecationWarning` was raised: {w}")

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")


@contextmanager
def no_deprecated_call(match: Optional[str] = None):
with no_warning_call(expected_warning=DeprecationWarning, match=match):
yield
13 changes: 8 additions & 5 deletions tests/loops/test_loop_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.loops import FitLoop
from pytorch_lightning.trainer.trainer import Trainer
Expand Down Expand Up @@ -80,14 +79,16 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"epoch_loop.val_loop._results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
"epoch_loop._results": {
"batch": None,
"batch_size": None,
"training": True,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand All @@ -106,8 +107,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand All @@ -122,8 +124,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"batch": None,
"batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
Expand Down
38 changes: 34 additions & 4 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -715,19 +716,15 @@ def on_validation_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_train_batch_start(self, batch, batch_idx):
assert self.trainer._results.batch_size == 2
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.log("on_train_batch_start", 1.0, reduce_fx="sum")

def on_train_batch_end(self, outputs, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_end", 1.0, reduce_fx="sum")

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

def training_epoch_end(self, *_) -> None:
Expand All @@ -749,3 +746,36 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
batch_size = BoringModel().train_dataloader().batch_size + 1
fast_dev_run = 2
log_val = 7

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which has multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def on_train_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with no_warning_call(match="Trying to infer the `batch_size`"):
trainer.fit(model)
7 changes: 5 additions & 2 deletions tests/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


def test_extract_batch_size():
"""Tests the behavior of extracting the batch size."""

def _check_warning_not_raised(data, expected):
with pytest.warns(None) as record:
with no_warning_call(match="Trying to infer the `batch_size`"):
assert extract_batch_size(data) == expected
assert len(record) == 0

def _check_warning_raised(data, expected):
with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
Expand All @@ -43,6 +43,9 @@ def _check_warning_raised(data, expected):
batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
_check_warning_not_raised(batch, 11)

batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
_check_warning_raised(batch, 1)

batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
_check_warning_raised(batch, 11)

Expand Down