Fix an issue that caused Lightning to extract the batch size even though it was set by the user in LightningModule.log #10408

Merged: 41 commits from fix/batch_size into master, Nov 19, 2021

Commits (41):
5e10b93  extract batch size only when it's required (rohitgr7, Nov 8, 2021)
8941e32  revamp log (rohitgr7, Nov 8, 2021)
51a7d36  fix (rohitgr7, Nov 8, 2021)
c2a3626  fix (rohitgr7, Nov 8, 2021)
be5882c  update tests (rohitgr7, Nov 8, 2021)
bba8055  chlog (rohitgr7, Nov 8, 2021)
308f9c0  fix (rohitgr7, Nov 9, 2021)
04314b8  add test (rohitgr7, Nov 9, 2021)
b1de03e  Merge branch 'master' into fix/batch_size (rohitgr7, Nov 9, 2021)
46a8861  mypy (rohitgr7, Nov 9, 2021)
9b28e44  mypy (rohitgr7, Nov 9, 2021)
0772967  deref current batch (rohitgr7, Nov 9, 2021)
c3567ff  mypy (rohitgr7, Nov 9, 2021)
2ec9885  Merge branch 'master' into fix/batch_size (rohitgr7, Nov 9, 2021)
1484284  rev (rohitgr7, Nov 9, 2021)
4363e69  cache batch size (rohitgr7, Nov 10, 2021)
94c5749  Merge branch 'master' into fix/batch_size (rohitgr7, Nov 10, 2021)
cb9faae  update test (rohitgr7, Nov 10, 2021)
61b9483  mypy (rohitgr7, Nov 10, 2021)
35c37ac  mypy (rohitgr7, Nov 10, 2021)
46c22be  move to resultcollection (rohitgr7, Nov 10, 2021)
9b25ecf  mypy (rohitgr7, Nov 10, 2021)
b7a2296  update logic (rohitgr7, Nov 10, 2021)
55a189e  Apply suggestions from code review (rohitgr7, Nov 18, 2021)
e89b6d2  Merge remote-tracking branch 'origin/master' into fix/batch_size (rohitgr7, Nov 18, 2021)
c9a8543  chlog (rohitgr7, Nov 18, 2021)
a053dd1  update on comments (rohitgr7, Nov 18, 2021)
a61dad1  Use our utilities (carmocca, Nov 19, 2021)
b9d2a56  Whitespace (carmocca, Nov 19, 2021)
f87d215  Remove unnecesary properties (carmocca, Nov 19, 2021)
0074b81  Remove current prefix (carmocca, Nov 19, 2021)
cb9de15  Simplify arguments (carmocca, Nov 19, 2021)
92cd293  Avoid indentations (carmocca, Nov 19, 2021)
6c9522d  Cache only if succesfully extracted (carmocca, Nov 19, 2021)
2979772  Merge branch 'master' into fix/batch_size (carmocca, Nov 19, 2021)
bb6f3c4  minor updates (rohitgr7, Nov 19, 2021)
250defa  Simplify check (carmocca, Nov 19, 2021)
1ee9b17  Remove silly comment (carmocca, Nov 19, 2021)
b806ab0  Merge branch 'master' into fix/batch_size (carmocca, Nov 19, 2021)
560da62  mypy fix (carmocca, Nov 19, 2021)
4c859d3  Merge branch 'master' into fix/batch_size (carmocca, Nov 19, 2021)
CHANGELOG.md (3 changes: 1 addition & 2 deletions)
@@ -142,8 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


-
- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408))


-
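
To put the changelog entry above in user terms, here is a minimal, hypothetical sketch (the module, layer sizes, and dict key are illustrative, not taken from this PR): when a batch is a structure whose batch size is ambiguous, such as a dict of tensors, the user can pass `batch_size` to `LightningModule.log`, and with this fix Lightning uses that value instead of trying to re-extract one from the batch.

```python
import torch
from pytorch_lightning import LightningModule


class DictBatchModel(LightningModule):
    """Hypothetical module whose batches are dicts, making batch-size inference ambiguous."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch["features"]  # assumed key produced by the user's dataloader
        loss = self.layer(x).sum()
        # With an explicit `batch_size`, the logged value is weighted by it for the
        # on-epoch mean instead of Lightning inferring a size from `batch`.
        self.log("train_loss", loss, on_epoch=True, batch_size=x.size(0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

The explicit value mainly matters for mean-reduced, on-epoch metrics, which is close to the condition the new logic in result.py below checks before inferring a size.
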
pytorch_lightning/loops/epoch/training_epoch_loop.py (6 changes: 1 addition & 5 deletions)
@@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

# cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
# different if tbptt is enabled
batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.trainer._results.batch_size = batch_size

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
@@ -210,7 +210,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:

def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
self._split_idx = split_idx
self.on_new_batch(split_batch)

def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
@@ -253,28 +252,23 @@ def _log_gpus_metrics(self) -> None:
Utilities and properties
"""

def on_new_batch(self, batch: Any) -> int:
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
assert self.trainer._results is not None
return self.trainer._results.extract_batch_size(batch)
return 1

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int, batch: Any) -> int:
def on_batch_start(self, batch_idx: int, batch: Any) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False
return self.on_new_batch(batch)
assert self.trainer._results is not None

# attach reference to the new batch and remove the cached batch_size
self.trainer._results.current_batch = batch
self.trainer._results.current_batch_size = None

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
assert self.trainer._results is not None
self.trainer._results.batch_size = 1

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
@@ -290,6 +284,11 @@ def on_batch_end(self) -> None:
self._progress_bar_metrics.update(metrics["pbar"])
self._callback_metrics.update(metrics["callback"])
self._logged_metrics.update(metrics["log"])
assert self.trainer._results is not None

# drop the reference to current batch and batch_size
self.trainer._results.current_batch = None
self.trainer._results.current_batch_size = None

def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
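
In short, the logger-connector hunks above stop computing a batch size eagerly: `on_batch_start` now only attaches a reference to the incoming batch and clears any cached size, and `on_batch_end` drops both again. A toy sketch of that lifecycle follows (an illustrative stand-in, not the actual `ResultCollection`):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ResultsBatchCache:
    """Toy stand-in for the two new ResultCollection attributes (illustration only)."""

    current_batch: Any = None
    current_batch_size: Optional[int] = None

    def on_batch_start(self, batch: Any) -> None:
        # keep only a reference; extraction is deferred until a logged metric needs it
        self.current_batch = batch
        self.current_batch_size = None

    def on_batch_end(self) -> None:
        # drop both so the batch can be garbage collected between iterations
        self.current_batch = None
        self.current_batch_size = None
```

Deferring the extraction this way is also why the training epoch loop above no longer needs to cache a batch size around the batch loop.
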
pytorch_lightning/trainer/connectors/logger_connector/result.py (67 changes: 43 additions & 24 deletions)
@@ -21,6 +21,7 @@
from typing_extensions import TypedDict

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.data import extract_batch_size
@@ -211,7 +212,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
if self.meta.is_mean_reduction:
self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum)

def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def update(self, value: _IN_METRIC, batch_size: int) -> None:
if self.is_tensor:
value = value.float()
if self.meta.on_step:
@@ -250,7 +251,7 @@ def reset(self) -> None:
self.value.reset()
self.has_reset = True

def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
def forward(self, value: _IN_METRIC, batch_size: int) -> None:
if self.meta.enable_graph:
with torch.no_grad():
self.update(value, batch_size)
@@ -376,7 +377,8 @@ class ResultCollection(dict):
def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
super().__init__()
self.training = training
self._batch_size = torch.tensor(1, device=device)
self._current_batch = None
self._current_batch_size: Optional[int] = None
self.device: Optional[Union[str, torch.device]] = device

@property
@@ -391,13 +393,41 @@ def append_fn(v: ResultMetric) -> None:
return o

@property
def batch_size(self) -> torch.Tensor:
# performance: cache the `batch_size` tensor instead of re-creating it
return self._batch_size
def current_batch_size(self) -> Optional[int]:
return self._current_batch_size

@batch_size.setter
def batch_size(self, value: int) -> None:
self._batch_size = torch.tensor(value, device=self.device)
@current_batch_size.setter
def current_batch_size(self, val: Optional[int]) -> None:
self._current_batch_size = val

@property
def current_batch(self) -> Any:
return self._current_batch

@current_batch.setter
def current_batch(self, data: Any) -> None:
self._current_batch = data

def _extract_batch_size(self, batch_size: Optional[int], on_epoch: bool, fx: str, meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = batch_size or self.current_batch_size

# extract batch size if it is None and whenever it is required
if batch_size is None:
fx_validate = _FxValidator.functions.get(fx.split(".")[0])
if on_epoch and fx_validate is not None and (True in fx_validate["on_step"]) and meta.is_mean_reduction:
try:
batch_size = extract_batch_size(self.current_batch)
except RecursionError:
batch_size = 1

# cache batch_size
self.current_batch_size = batch_size
else:
batch_size = 1

return batch_size

def log(
self,
Expand Down Expand Up @@ -458,10 +488,8 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

if batch_size is not None:
self.batch_size = batch_size

self.update_metrics(key, value)
batch_size = self._extract_batch_size(batch_size, on_epoch, fx, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value.
@@ -478,10 +506,10 @@ def fn(v: _IN_METRIC) -> ResultMetric:
value = ResultMetricCollection(value)
self[key] = value

def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), self.batch_size)
result_metric.forward(v.to(self.device), batch_size)
result_metric.has_reset = False

apply_to_collections(self[key], value, ResultMetric, fn)
@@ -575,19 +603,10 @@ def fn(item: ResultMetric) -> None:

apply_to_collection(self, ResultMetric, fn)

def extract_batch_size(self, batch: Any) -> int:
try:
batch_size = extract_batch_size(batch)
except RecursionError:
batch_size = 1
self.batch_size = batch_size # the setter converts it to `Tensor`
return batch_size

def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

self._batch_size = self._batch_size.to(*args, **kwargs)
if "device" in kwargs:
self.device = kwargs["device"]
return self
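
The heart of the result.py changes is `_extract_batch_size`, which resolves a batch size in a fixed fallback order: an explicitly logged `batch_size` wins, then a value already cached for the current batch, then inference from the batch itself (only when an on-epoch, mean-reduced metric actually needs it), and finally a default of 1. Below is a self-contained sketch of that order; `infer_batch_size` is a simplified stand-in for Lightning's `extract_batch_size` utility in `pytorch_lightning.utilities.data`, and both function names are illustrative, not the library code.

```python
from typing import Any, Optional

import torch


def infer_batch_size(batch: Any) -> int:
    """Simplified stand-in for `pytorch_lightning.utilities.data.extract_batch_size`."""
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, (list, tuple)) and batch:
        return infer_batch_size(batch[0])
    if isinstance(batch, dict) and batch:
        return infer_batch_size(next(iter(batch.values())))
    raise ValueError("could not infer a batch size from this batch structure")


def resolve_batch_size(
    explicit: Optional[int], cached: Optional[int], batch: Any, needs_weighting: bool
) -> int:
    """Mirror of the fallback order in `_extract_batch_size` (a sketch, not the real method)."""
    if explicit is not None:  # the user passed `batch_size=...` to `self.log`
        return explicit
    if cached is not None:  # already extracted earlier for this batch
        return cached
    if needs_weighting:  # only on-epoch, mean-reduced metrics need a real value
        try:
            # the library catches RecursionError from its recursive utility;
            # this simplified stand-in raises ValueError instead
            return infer_batch_size(batch)
        except (ValueError, RecursionError):
            return 1
    return 1


# usage: an explicit value always wins, even for an ambiguous dict batch
assert resolve_batch_size(8, None, {"a": torch.randn(18, 4)}, needs_weighting=True) == 8
```

The real method also writes the inferred value back to `current_batch_size`, so later `log` calls within the same batch reuse it instead of walking the batch again.
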
tests/loops/test_loop_state_dict.py (13 changes: 8 additions & 5 deletions)
@@ -14,7 +14,6 @@
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.loops import FitLoop
from pytorch_lightning.trainer.trainer import Trainer
@@ -80,14 +79,16 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"epoch_loop.val_loop._results": {
"_current_batch": None,
"_current_batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
"epoch_loop._results": {
"_current_batch": None,
"_current_batch_size": None,
"training": True,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
@@ -106,8 +107,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"_current_batch": None,
"_current_batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
@@ -122,8 +124,9 @@ def test_loops_state_dict_structure():
"is_last_batch": False,
},
"_results": {
"_current_batch": None,
"_current_batch_size": None,
"training": False,
"_batch_size": torch.tensor(1),
"device": None,
"items": {},
},
tests/trainer/logging_/test_train_loop_logging.py (39 changes: 35 additions & 4 deletions)
@@ -715,19 +715,15 @@ def on_validation_epoch_end(self):
assert all(v == 3 for v in self.trainer.callback_metrics.values())

def on_train_batch_start(self, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_start", 1.0, reduce_fx="sum")

def on_train_batch_end(self, outputs, batch, batch_idx):
assert self.trainer._results.batch_size == 2
self.log("on_train_batch_end", 1.0, reduce_fx="sum")

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
assert self.trainer._results.batch_size == 2
self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

def training_epoch_end(self, *_) -> None:
@@ -749,3 +745,38 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 10
fast_dev_run = 2
log_val = 7.0

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size + 10, 10), "batch2": batch}
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def training_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with pytest.warns(None) as record:
trainer.fit(model)

assert not any("Trying to infer the `batch_size`" in warn.message.args[0] for warn in record.list)