[FEAT] logging refactors 1/n #4439

Merged 28 commits on Nov 2, 2020

Commits
e349141
introducing new logging object
tchaton Oct 30, 2020
8417a15
typo
tchaton Oct 30, 2020
e21b249
typo
tchaton Oct 30, 2020
edd3818
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Oct 30, 2020
f9022d7
Update pytorch_lightning/trainer/logging.py
tchaton Oct 30, 2020
f7fd584
Update pytorch_lightning/trainer/logging.py
tchaton Oct 30, 2020
de910e1
update on comments
tchaton Oct 30, 2020
eaf349b
update on comments
tchaton Oct 30, 2020
b81aecf
add more docstring
tchaton Oct 30, 2020
e80f2e8
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Oct 30, 2020
694ded2
Update pytorch_lightning/core/lightning.py
tchaton Oct 30, 2020
cab95b8
resolve on comments
tchaton Oct 30, 2020
664af72
Merge branch 'feat/epoch_loop_result_objs' of https://github.com/PyTo…
tchaton Oct 30, 2020
4190cf5
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
0fbe5ea
solve pyright
tchaton Nov 2, 2020
46b76f6
Merge branch 'feat/epoch_loop_result_objs' of https://github.com/PyTo…
tchaton Nov 2, 2020
1c47c8a
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
49a491a
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 2, 2020
ac09913
update on comments
tchaton Nov 2, 2020
f8f06cf
Merge branch 'feat/epoch_loop_result_objs' of https://github.com/PyTo…
tchaton Nov 2, 2020
a9f6e56
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
e604313
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton Nov 2, 2020
14f92b0
update on comments
tchaton Nov 2, 2020
92c8a3f
Merge branch 'feat/epoch_loop_result_objs' of https://github.com/PyTo…
tchaton Nov 2, 2020
55c7a2c
Merge branch 'master' into feat/epoch_loop_result_objs
SeanNaren Nov 2, 2020
05baa28
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
518929d
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
80e825c
Merge branch 'master' into feat/epoch_loop_result_objs
tchaton Nov 2, 2020
29 changes: 22 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import collections
import copy
import inspect
import os
import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
@@ -28,16 +27,17 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._current_hook_fx_name = None
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
@@ -244,6 +246,18 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self._current_hook_fx_name is not None:
self.trainer.logger_connector.check_logging_in_callbacks(
self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch
)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
value,
@@ -257,7 +271,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
@@ -1277,11 +1292,11 @@ def tbptt_split_batch(self, batch, split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t : t + split_size]
split_x = x[:, t: t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
split_x[batch_idx] = x[batch_idx][t: t + split_size]

batch_split.append(split_x)

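To make the `lightning.py` changes above concrete: `LightningModule.log` now records which hook it was called from (`_current_hook_fx_name`), asks the logger connector whether logging is allowed there, rejects keys that already contain `/dataloader_idx_`, and forwards `self._current_dataloader_idx` to `Result.log` so the suffix can be attached automatically. A minimal sketch of how user code is expected to call `log` after this change; the module and metric names are illustrative placeholders, not code from this PR:

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative module; only the self.log() calls matter here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Log a plain key; with multiple val dataloaders the logger connector is
        # expected to key it as "val_loss/dataloader_idx_<i>" automatically.
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        # After this PR, embedding the suffix yourself raises
        # MisconfigurationException inside LightningModule.log():
        # self.log(f"val_loss/dataloader_idx_{dataloader_idx}", loss)
        return loss
```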
77 changes: 56 additions & 21 deletions pytorch_lightning/core/step_result.py
@@ -124,6 +124,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
@@ -144,6 +145,7 @@

# set step version
step_name = f'{name}_step'

self.__set_meta(
step_name,
value,
@@ -154,12 +156,15 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)

self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'

self.__set_meta(
epoch_name,
value,
@@ -170,7 +175,8 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)
self.__setitem__(epoch_name, value)

@@ -185,7 +191,8 @@ def log(
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked
forked=was_forked,
dataloader_idx=dataloader_idx,
)

# set the value
Expand All @@ -202,7 +209,8 @@ def __set_meta(
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool
forked: bool,
dataloader_idx: Union[int, None]
):
# set the meta for the item
meta_value = value
@@ -215,7 +223,8 @@
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked
forked=forked,
dataloader_idx=dataloader_idx,
)

self['meta'][name] = meta
@@ -225,13 +234,22 @@ def __set_meta(
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

def track_batch_size(self, batch):
batch_size = Result.extract_batch_size(batch)
Result.attach_batch_size(batch_size, self)

@staticmethod
def extract_batch_size(batch):
try:
batch_size = Result.unpack_batch_size(batch)
except RecursionError as re:
batch_size = 1
return batch_size

meta = self['meta']
meta['_internal']['batch_sizes'].append(batch_size)
@staticmethod
def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None:
if batch_size is not None:
meta = result['meta']
meta['_internal']['batch_sizes'].append(batch_size)

def get_batch_sizes(self):
meta = self['meta']
@@ -242,7 +260,12 @@ def get_callback_metrics(self) -> dict:

return result

def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
if dataloader_idx is not None and add_dataloader_idx:
return f"{k}/dataloader_idx_{dataloader_idx}"
return k

def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of the batch step

@@ -257,15 +280,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache.detach()
result[dl_key] = self[k]._forward_cache.detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_epoch_log_metrics(self) -> dict:
def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of epoch
"""
@@ -279,19 +304,21 @@ def get_epoch_log_metrics(self) -> dict:
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute().detach()
result[dl_key] = self[k].compute().detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_epoch_pbar_metrics(self):
def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -305,19 +332,21 @@ def get_epoch_pbar_metrics(self):
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute().detach()
result[dl_key] = self[k].compute().detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_forked_metrics(self):
def get_forked_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -328,12 +357,14 @@ def get_forked_metrics(self):
if k == '_internal':
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['forked']:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_batch_pbar_metrics(self, include_forked_originals=True):
def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of the batch step
"""
@@ -347,11 +378,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True):
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
result[dl_key] = self[k]._forward_cache
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

@@ -473,6 +506,8 @@ def reduce_on_epoch_end(cls, outputs):
if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
if isinstance(result[k], list):
result[k] = torch.tensor(result[k]).float()
try:
reduced_val = weighted_mean(result[k], batch_sizes)
except Exception as e:
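To summarise the keying rule that the `step_result.py` getters above now share: `_add_dataloader_idx` suffixes a metric key with `/dataloader_idx_<i>` only when a dataloader index was recorded for it and the caller passes `add_dataloader_idx=True`. A standalone sketch of that rule, pulled out of the `Result` class purely for illustration:

```python
from typing import Optional


def _add_dataloader_idx(k: str, dataloader_idx: Optional[int], add_dataloader_idx: bool) -> str:
    # Mirrors Result._add_dataloader_idx in the diff above: suffix the key only
    # when an index was logged and the caller asked for per-dataloader keys.
    if dataloader_idx is not None and add_dataloader_idx:
        return f"{k}/dataloader_idx_{dataloader_idx}"
    return k


# Expected behaviour:
assert _add_dataloader_idx("val_loss", None, True) == "val_loss"    # single dataloader: key unchanged
assert _add_dataloader_idx("val_loss", 1, False) == "val_loss"      # caller did not opt in
assert _add_dataloader_idx("val_loss", 1, True) == "val_loss/dataloader_idx_1"  # per-dataloader key
```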
@@ -0,0 +1 @@
from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector
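The final one-line hunk appears to be the `__init__.py` of a new `logger_connector` package (an assumption based on the import path; the scraped page does not show the file name). Its only job is to re-export `LoggerConnector`, so both of the imports below should resolve to the same class under that assumption:

```python
# Hedged sketch: assumes the new file is
# pytorch_lightning/trainer/connectors/logger_connector/__init__.py.
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import (
    LoggerConnector as _LoggerConnector,
)

assert LoggerConnector is _LoggerConnector
```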