Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an issue that caused Lightning to extract the batch size even though it was set by the user in LightningModule.log #10408

Merged
merged 41 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5e10b93
extract batch size only when it's required
rohitgr7 Nov 8, 2021
8941e32
revamp log
rohitgr7 Nov 8, 2021
51a7d36
fix
rohitgr7 Nov 8, 2021
c2a3626
fix
rohitgr7 Nov 8, 2021
be5882c
update tests
rohitgr7 Nov 8, 2021
bba8055
chlog
rohitgr7 Nov 8, 2021
308f9c0
fix
rohitgr7 Nov 9, 2021
04314b8
add test
rohitgr7 Nov 9, 2021
b1de03e
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 9, 2021
46a8861
mypy
rohitgr7 Nov 9, 2021
9b28e44
mypy
rohitgr7 Nov 9, 2021
0772967
deref current batch
rohitgr7 Nov 9, 2021
c3567ff
mypy
rohitgr7 Nov 9, 2021
2ec9885
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 9, 2021
1484284
rev
rohitgr7 Nov 9, 2021
4363e69
cache batch size
rohitgr7 Nov 10, 2021
94c5749
Merge branch 'master' into fix/batch_size
rohitgr7 Nov 10, 2021
cb9faae
update test
rohitgr7 Nov 10, 2021
61b9483
mypy
rohitgr7 Nov 10, 2021
35c37ac
mypy
rohitgr7 Nov 10, 2021
46c22be
move to resultcollection
rohitgr7 Nov 10, 2021
9b25ecf
mypy
rohitgr7 Nov 10, 2021
b7a2296
update logic
rohitgr7 Nov 10, 2021
55a189e
Apply suggestions from code review
rohitgr7 Nov 18, 2021
e89b6d2
Merge remote-tracking branch 'origin/master' into fix/batch_size
rohitgr7 Nov 18, 2021
c9a8543
chlog
rohitgr7 Nov 18, 2021
a053dd1
update on comments
rohitgr7 Nov 18, 2021
a61dad1
Use our utilities
carmocca Nov 19, 2021
b9d2a56
Whitespace
carmocca Nov 19, 2021
f87d215
Remove unnecessary properties
carmocca Nov 19, 2021
0074b81
Remove current prefix
carmocca Nov 19, 2021
cb9de15
Simplify arguments
carmocca Nov 19, 2021
92cd293
Avoid indentations
carmocca Nov 19, 2021
6c9522d
Cache only if successfully extracted
carmocca Nov 19, 2021
2979772
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
bb6f3c4
minor updates
rohitgr7 Nov 19, 2021
250defa
Simplify check
carmocca Nov 19, 2021
1ee9b17
Remove silly comment
carmocca Nov 19, 2021
b806ab0
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
560da62
mypy fix
carmocca Nov 19, 2021
4c859d3
Merge branch 'master' into fix/batch_size
carmocca Nov 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))


-


-

- Fixed the call to `extract_batch_size` only when it's required ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408))

-


## [1.5.0] - 2021-11-02
Expand Down
248 changes: 133 additions & 115 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
Expand Down Expand Up @@ -358,122 +359,25 @@ def log(
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
if tbptt_reduce_fx is not None:
rank_zero_deprecation(
"`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6."
" Please, open a discussion explaining your use-case in"
" `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
)
if tbptt_pad_token is not None:
rank_zero_deprecation(
"`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6."
" Please, open a discussion explaining your use-case in"
" `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
)
if sync_dist_op is not None:
rank_zero_deprecation(
f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6."
f" Use `self.log(reduce_fx={sync_dist_op})` instead."
)
if reduce_fx == "default":
reduce_fx = sync_dist_op
elif reduce_fx == "default":
reduce_fx = "mean"

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
apply_to_collection(
value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
)

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self.trainer is None:
# not an error to support testing the `*_step` methods without a `Trainer` reference
rank_zero_warn(
"You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
" This is most likely because the model hasn't been passed to the `Trainer`"
)
return
results = self.trainer._results
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)
if self._current_fx_name is None:
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)
_FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"You called `self.log` with the key `{name}`"
" but it should not contain information about `dataloader_idx`"
)

value = apply_to_collection(value, numbers.Number, self.__to_tensor)

if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

if metric_attribute is None and isinstance(value, Metric):
if self._metric_attributes is None:
# compute once
self._metric_attributes = {
id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
}
if not self._metric_attributes:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
" You can fix this by setting an attribute for the metric in your `LightningModule`."
)
# try to find the passed metric in the LightningModule
metric_attribute = self._metric_attributes.get(id(value), None)
if metric_attribute is None:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
f" of {list(self._metric_attributes.values())}"
)

if (
self.trainer.training
and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
and batch_size is None
):
raise MisconfigurationException(
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
)

results.log(
self._current_fx_name,
name,
value,
self.log_dict(
dictionary={name: value},
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
sync_dist_op=sync_dist_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
)

self.trainer.logger_connector._current_fx = self._current_fx_name

def log_dict(
self,
dictionary: Mapping[str, _METRIC_COLLECTION],
Expand All @@ -490,6 +394,7 @@ def log_dict(
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: Optional[bool] = None,
) -> None:
"""Log a dictionary of values at once.
Expand All @@ -516,29 +421,142 @@ def log_dict(
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
metric_attribute: To restore the metric state, Lightning requires the reference of the
:class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
for k, v in dictionary.items():
self.log(
name=k,
value=v,
if self.trainer is None:
# not an error to support testing the `*_step` methods without a `Trainer` reference
rank_zero_warn(
"You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
" This is most likely because the model hasn't been passed to the `Trainer`"
)
return

if (
self.trainer.training
and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
and batch_size is None
):
raise MisconfigurationException(
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
)

results = self.trainer._results
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)
if self._current_fx_name is None:
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
_FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

if isinstance(reduce_fx, str):
reduce_fx = reduce_fx.lower()

if tbptt_reduce_fx is not None:
rank_zero_deprecation(
"`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6."
" Please, open a discussion explaining your use-case in"
" `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
)
if tbptt_pad_token is not None:
rank_zero_deprecation(
"`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6."
" Please, open a discussion explaining your use-case in"
" `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
)
if sync_dist_op is not None:
rank_zero_deprecation(
f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6."
f" Use `self.log(reduce_fx={sync_dist_op})` instead."
)
if reduce_fx == "default":
reduce_fx = sync_dist_op
elif reduce_fx == "default":
reduce_fx = "mean"

if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

# extract batch size if it's None whenever it's required
if batch_size is None:
if (
on_epoch
and True in _FxValidator.functions[self._current_fx_name]["on_step"]
and reduce_fx in ("mean", "avg")
):
batch_size = extract_batch_size(self.trainer._results.current_batch)
else:
batch_size = 1

for name, value in dictionary.items():
# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"You called `self.log` with the key `{name!r}`"
" but it should not contain information about `dataloader_idx`"
)

# check for invalid values
apply_to_collection(value, dict, self.__check_not_nested, name)
apply_to_collection(
value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
)
value = apply_to_collection(value, numbers.Number, self.__to_tensor)

if metric_attribute is None and isinstance(value, Metric):
if self._metric_attributes is None:
# compute once
self._metric_attributes = {
id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
}
if not self._metric_attributes:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
" You can fix this by setting an attribute for the metric in your `LightningModule`."
)
# try to find the passed metric in the LightningModule
metric_attribute = self._metric_attributes.get(id(value), None)
if metric_attribute is None:
raise MisconfigurationException(
"Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name`"
f" is one of {list(self._metric_attributes.values())}."
)

results.log(
self._current_fx_name,
name,
value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
sync_dist=sync_dist,
sync_dist_group=sync_dist_group,
sync_dist_op=sync_dist_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
)

self.trainer.logger_connector._current_fx = self._current_fx_name

@staticmethod
def __check_not_nested(value: dict, name: str) -> dict:
# self-imposed restriction. for simplicity
Expand Down
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

# cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
# different if tbptt is enabled
batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
Expand Down Expand Up @@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.trainer._results.batch_size = batch_size

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:

def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
self._split_idx = split_idx
self.on_new_batch(split_batch)

def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
Expand Down Expand Up @@ -253,28 +252,20 @@ def _log_gpus_metrics(self) -> None:
Utilities and properties
"""

def on_new_batch(self, batch: Any) -> int:
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter):
assert self.trainer._results is not None
return self.trainer._results.extract_batch_size(batch)
return 1

def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int, batch: Any) -> int:
def on_batch_start(self, batch_idx: int, batch: Any) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False
return self.on_new_batch(batch)
assert self.trainer._results is not None
self.trainer._results.current_batch = batch

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
assert self.trainer._results is not None
self.trainer._results.batch_size = 1

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
Expand All @@ -290,6 +281,8 @@ def on_batch_end(self) -> None:
self._progress_bar_metrics.update(metrics["pbar"])
self._callback_metrics.update(metrics["callback"])
self._logged_metrics.update(metrics["log"])
assert self.trainer._results is not None
self.trainer._results.current_batch = None

def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
Expand Down
Loading