From e349141e03d4af1181713c534c0cc6db37313f33 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 10:03:51 +0000 Subject: [PATCH 01/15] introducing new logging object --- pytorch_lightning/core/lightning.py | 28 +- pytorch_lightning/core/step_result.py | 77 ++- .../connectors/logger_connector/__init__.py | 3 + .../callback_hook_validator.py | 220 ++++++++ .../logger_connector/epoch_loop_result.py | 470 ++++++++++++++++++ .../logger_connector.py | 79 ++- pytorch_lightning/trainer/evaluation_loop.py | 15 +- pytorch_lightning/trainer/logging.py | 4 +- pytorch_lightning/trainer/trainer.py | 21 +- pytorch_lightning/trainer/training_loop.py | 57 ++- .../trainer/logging/test_logger_connector.py | 249 ++++++++++ 11 files changed, 1163 insertions(+), 60 deletions(-) create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/__init__.py create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py rename pytorch_lightning/trainer/connectors/{ => logger_connector}/logger_connector.py (86%) create mode 100644 tests/trainer/logging/test_logger_connector.py diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 22d63d0a03a74..d0c7cf1fdb465 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import tempfile import collections import copy import inspect -import os import re -import tempfile from abc import ABC from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping @@ -28,16 +27,17 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.parsing import ( AttributeDict, collect_init_args, get_init_args, ) +from pytorch_lightning.callbacks import Callback from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer @@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs): self._datamodule = None self._results: Optional[Result] = None self._current_fx_name = '' + self._current_hook_fx_name = '' + self._current_dataloader_idx = None def optimizers(self): opts = self.trainer.optimizers @@ -244,6 +246,17 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + if self._current_hook_fx_name != '': + self.trainer.logger_connector\ + .check_logging_in_callbacks(self._current_hook_fx_name, + on_step=on_step, + on_epoch=on_epoch) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"Logged key: {name} should not contain information 
about dataloader_idx.") + self._results.log( name, value, @@ -257,7 +270,8 @@ def log( enable_graph, sync_dist, sync_dist_op, - sync_dist_group + sync_dist_group, + self._current_dataloader_idx, ) def log_dict( @@ -1278,11 +1292,11 @@ def tbptt_split_batch(self, batch, split_size): batch_split = [] for i, x in enumerate(batch): if isinstance(x, torch.Tensor): - split_x = x[:, t : t + split_size] + split_x = x[:, t: t + split_size] elif isinstance(x, collections.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t : t + split_size] + split_x[batch_idx] = x[batch_idx][t: t + split_size] batch_split.append(split_x) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 650c1876d0cd0..f86519b5a0ad9 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -124,6 +124,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -144,6 +145,7 @@ def log( # set step version step_name = f'{name}_step' + self.__set_meta( step_name, value, @@ -154,12 +156,15 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) + self.__setitem__(step_name, value) # set epoch version epoch_name = f'{name}_epoch' + self.__set_meta( epoch_name, value, @@ -170,7 +175,8 @@ def log( reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=False + forked=False, + dataloader_idx=dataloader_idx, ) self.__setitem__(epoch_name, value) @@ -185,7 +191,8 @@ def log( reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=was_forked + forked=was_forked, + dataloader_idx=dataloader_idx, ) # set the value @@ -202,7 +209,8 @@ def __set_meta( reduce_fx: Callable, tbptt_pad_token: int, tbptt_reduce_fx: Callable, - forked: bool + forked: bool, + dataloader_idx: Union[int, None] ): # set the meta for the item meta_value = value @@ -215,7 +223,8 @@ def __set_meta( value=meta_value, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, - forked=forked + forked=forked, + dataloader_idx=dataloader_idx, ) self['meta'][name] = meta @@ -225,13 +234,22 @@ def __set_meta( _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) def track_batch_size(self, batch): + batch_size = Result.extract_batch_size(batch) + Result.attach_batch_size(batch_size, self) + + @staticmethod + def extract_batch_size(batch): try: batch_size = Result.unpack_batch_size(batch) except RecursionError as re: batch_size = 1 + return batch_size - meta = self['meta'] - meta['_internal']['batch_sizes'].append(batch_size) + @staticmethod + def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None: + if batch_size is not None: + meta = result['meta'] + meta['_internal']['batch_sizes'].append(batch_size) def get_batch_sizes(self): meta = self['meta'] @@ -242,7 +260,12 @@ def get_callback_metrics(self) -> dict: return result - def get_batch_log_metrics(self, include_forked_originals=True) -> dict: + def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: + if dataloader_idx is not None and add_dataloader_idx: + return f"{k}/dataloader_idx_{dataloader_idx}" + return k + + 
def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of the batch step @@ -257,15 +280,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict: if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache else: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_epoch_log_metrics(self) -> dict: + def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -279,11 +304,13 @@ def get_epoch_log_metrics(self) -> dict: if options['forked']: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -291,7 +318,7 @@ def get_epoch_log_metrics(self) -> dict: return result - def get_epoch_pbar_metrics(self): + def get_epoch_pbar_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -305,11 +332,13 @@ def get_epoch_pbar_metrics(self): if options['forked']: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): - result[k] = self[k].compute() + result[dl_key] = self[k].compute() else: - result[k] = self[k] + result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # compute metric on epoch anyway so state does not accumulate @@ -317,7 +346,7 @@ def get_epoch_pbar_metrics(self): return result - def get_forked_metrics(self): + def get_forked_metrics(self, add_dataloader_idx=False): """ Gets the metrics to log at the end of epoch """ @@ -328,12 +357,14 @@ def get_forked_metrics(self): if k == '_internal': continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['forked']: - result[k] = self[k] + result[dl_key] = self[k] return result - def get_batch_pbar_metrics(self, include_forked_originals=True): + def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): """ Gets the metrics to log at the end of the batch step """ @@ -347,11 +378,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True): if options['forked'] and not include_forked_originals: continue + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + if options['prog_bar'] and options['on_step']: if isinstance(self[k], Metric): - result[k] = self[k]._forward_cache + result[dl_key] = self[k]._forward_cache else: - result[k] = self[k] + result[dl_key] = self[k] return result @@ -473,6 +506,8 @@ def reduce_on_epoch_end(cls, outputs): if option['on_epoch']: fx = option['reduce_fx'] if fx == torch.mean: + if isinstance(result[k], list): + result[k] = torch.tensor(result[k]).float() try: reduced_val = weighted_mean(result[k], batch_sizes) except Exception as e: diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py new file mode 100644 index 0000000000000..dd3b1c2e247e7 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -0,0 +1,3 @@ +from pytorch_lightning.trainer.connectors.logger_connector.epoch_loop_result import EpochResultStore, LoggerStages +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py new file mode 100644 index 0000000000000..3ce4b523545c3 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class CallbackHookNameValidator: + + @staticmethod + def check_logging_in_callbacks(current_hook_fx_name: str = None, on_step: bool = None, + on_epoch: bool = None) -> None: + if current_hook_fx_name is None: + return + + internal_func = getattr(CallbackHookNameValidator, f"_{current_hook_fx_name}_log", None) + + if internal_func is None: + return + + current_callback_hook_auth_args = internal_func() + + if current_callback_hook_auth_args is not None: + m = "{} function supports only {} in {}. Provided {}" + if on_step not in current_callback_hook_auth_args["on_step"]: + msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step) + raise MisconfigurationException(msg) + + if on_epoch not in current_callback_hook_auth_args["on_epoch"]: + msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch) + raise MisconfigurationException(msg) + else: + raise MisconfigurationException( + f"{current_hook_fx_name} function doesn't support logging using self.log() yet." 
+ ) + + @staticmethod + def _setup_log(): + """Called when fit or test begins""" + return None + + @staticmethod + def _teardown_log(): + """Called at the end of fit and test""" + return None + + @staticmethod + def _on_init_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_init_end_log(): + """Called when the trainer initialization ends, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_end_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_sanity_check_start_log(): + """Called when the validation sanity check starts.""" + return None + + @staticmethod + def _on_sanity_check_end_log(): + """Called when the validation sanity check ends.""" + return None + + @staticmethod + def _on_train_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_train_start_log(): + """Called when the train begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_pretrain_routine_start_log(): + """Called when the train begins.""" + return None + + @staticmethod + def _on_pretrain_routine_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_start_log(): + """Called when the validation batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_end_log(): + """Called when the validation batch ends.""" + 
return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_start_log(): + """Called when the test batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_end_log(): + """Called when the test batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_start_log(): + """Called when the validation loop begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_end_log(): + """Called when the validation loop ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_start_log(): + """Called when the test begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_end_log(): + """Called when the test ends.""" + return None + + @staticmethod + def _on_keyboard_interrupt_log(): + """Called when the training is interrupted by KeyboardInterrupt.""" + return None + + @staticmethod + def _on_save_checkpoint_log(): + """Called when saving a model checkpoint.""" + return None + + @staticmethod + def _on_load_checkpoint_log(): + """Called when loading a model checkpoint.""" + return None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py new file mode 100644 index 0000000000000..4b3af4f4af8da --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py @@ -0,0 +1,470 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from copy import deepcopy +from enum import Enum +from typing import Union, Tuple, Any + +from pytorch_lightning.core.step_result import Result + + +class LoggerStages(Enum): + TRAIN = "train" + VAL = "validation" + TEST = "test" + + +class HookResultStore: + """ + This class is used to hold all metrics logged during one callback or model hook. + Can be used for both training, val, test. + + Result objects will be stored in the following way. 
+ + val and test: self._internals = {"dataloader_idx": [Result(), ..., Result()]} + training + - IF optimizer_idx and training_step_idx are set, + THEN self._internals = {"dataloader_idx": + {"optimizer_idx": + {"training_step_idx": + [Result(), Result()] + } + } + } + - ELSE self._internals = {"dataloader_idx": [Result(), ..., Result()]} + """ + + _types = ["list", "dict"] + + def __init__(self, fx_name): + self._fx_name = fx_name + self._internals = {} + self._internals_reduced = {} + self._internal_type = None + self.has_reduced = False + + def get_reduced_metrics(self): + return self._internals_reduced + + def add_dataloader_idx(self): + return len(self._internals) > 1 + + @property + def num_dataloaders(self): + return len(self._internals) + + def get_latest_from_dict(self, dl_idx): + num_opt_idx = len(self._internals[dl_idx]) - 1 + assert num_opt_idx >= 0 + num_opt_idx = str(num_opt_idx) + num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1 + batch_indexes = [*self._internals[dl_idx][num_opt_idx].keys()] + # sort them by increasing order + batch_indexes.sort(key=float) + assert num_batch_idx >= 0 + return self._internals[dl_idx][num_opt_idx][batch_indexes[-1]][-1] + + def check_dataloader_idx(self, result: Result) -> bool: + add_dataloader_idx = False + try: + if len(result.keys()) > 1: + random_key = [*result.keys()][-1] + add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None + return add_dataloader_idx + return add_dataloader_idx + except Exception: + return add_dataloader_idx + + def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): + results = {} + if latest: + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + if self._internal_type == self._types[0]: + latest_result = self._internals[dl_idx][-1] + else: + latest_result = self.get_latest_from_dict(dl_idx) + add_dataloader_idx = self.check_dataloader_idx(latest_result) + func = getattr(latest_result, func_name) + results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) + return results + raise NotImplementedError + + def get_batch_pbar_metrics(self, latest=True, *args, **kwargs): + return self.get_lastest_from_func_name("get_batch_pbar_metrics", *args, latest=latest, **kwargs) + + def get_batch_log_metrics(self, latest=True, *args, **kwargs): + return self.get_lastest_from_func_name("get_batch_log_metrics", *args, latest=latest, **kwargs) + + def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs): + if isinstance(opt_metric, Result): + func = getattr(opt_metric, func_name) + metrics_to_log = func( + *args, + add_dataloader_idx=self.add_dataloader_idx, + **kwargs) + results.update(metrics_to_log) + else: + raise Exception("The provided opt_metric should be a Result Object. 
Something is wrong") + + def get_epoch_from_func_name(self, func_name, *args, **kwargs): + results = {} + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + opt_metrics = self._internals_reduced[dl_idx] + if isinstance(opt_metrics, defaultdict): + for opt_metric in opt_metrics.values(): + self.run_epoch_func(results, opt_metric, func_name, *args, **kwargs) + else: + self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) + return results + + def get_epoch_pbar_metrics(self, *args, **kwargs): + return self.get_epoch_from_func_name("get_epoch_pbar_metrics") + + def get_epoch_log_metrics(self, *args, **kwargs): + return self.get_epoch_from_func_name("get_epoch_log_metrics") + + def get_forked_metrics(self, *args, **kwargs): + return self.get_epoch_from_func_name("get_forked_metrics") + + @staticmethod + def _append_to_structure(primary_dict, opt_idx, batch_idx, result): + if opt_idx not in primary_dict: + primary_dict[opt_idx] = {} + + if batch_idx not in primary_dict[opt_idx]: + primary_dict[opt_idx][batch_idx] = [] + + primary_dict[opt_idx][batch_idx].append(result) + + def append(self, result, dataloader_idx=None, extra_info: dict = {}): + + assert isinstance(result, Result) + + if dataloader_idx is None: + dataloader_idx = 0 + + primary_key = f"{dataloader_idx}" + + # [dataloader_idx][optimizer_idx][training_step_idx] is a list + if len(extra_info) > 0: + self._internal_type = self._types[-1] + # initialize dictionary + if primary_key not in self._internals: + self._internals[primary_key] = {} + self._internals_reduced[primary_key] = defaultdict(dict) + + # extract infos + opt_idx = str(extra_info["opt_idx"]) + batch_idx = str(extra_info["batch_idx"]) + + self._append_to_structure(self._internals[primary_key], opt_idx, batch_idx, result) + + # [dataloader_idx] is a list + else: + self._internal_type = self._types[0] + if primary_key not in self._internals: + self._internals[primary_key] = [] + self._internals[primary_key].append(result) + + def auto_reduce_results_on_epoch_end(self): + if not self.has_reduced: + epoch_log_metrics = {} + epoch_progress_bar_metrics = {} + + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + epoch_metrics = self._internals[dl_idx] + + if self._internal_type == self._types[-1]: + + num_opt_idx = len(self._internals[dl_idx]) - 1 + + # Make sure we didn't create key + assert num_opt_idx >= 0 + + for opt_idx in range(num_opt_idx + 1): + opt_idx = str(opt_idx) + # TODO: Figure out to reduce memory + # TODO: How to start training in middle of epoch + opt_outputs = deepcopy(epoch_metrics[opt_idx]) + + num_batch_idx = len(self._internals[dl_idx][str(num_opt_idx)]) - 1 + assert num_batch_idx >= 0 + batch_indexes = self._internals[dl_idx][str(num_opt_idx)].keys() + + # reduce across time first + time_reduced_outputs = [] + for batch_idx in batch_indexes: + batch_idx = str(batch_idx) + tbptt_outs = opt_outputs[str(batch_idx)] + tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) + if len(tbptt_outs) > 1: + time_reduced_outputs.append(tbptt_outs) + + if len(time_reduced_outputs) == 0: + continue + + # reduce across training steps + opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) + + # with manual opt need 1 + metrics because meta is always there + if opt_outputs.minimize is not None: + opt_outputs.minimize = opt_outputs.minimize.mean() + + self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs + else: + # no need to reduce as called only once + if 
len(epoch_metrics) == 1: + reduced_epoch_metrics = deepcopy(epoch_metrics[0]) + else: + reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(deepcopy(epoch_metrics)) + + self._internals_reduced[dl_idx] = reduced_epoch_metrics + + self.has_reduced = True + + def __getitem__(self, key: str) -> Any: + try: + if key in self._internals: + return self._internals[key] + return self[key] + except KeyError: + return None + + def __repr__(self): + return self._internals.__repr__() + + +class EpochResultStore: + """ + This class is responsible for caching all logging metrics which happened during one epoch + + It will cache Result objects as follows. + + self._internals = {"fx_name_0": HookResult(), ..., "fx_name_n": HookResult()} + """ + def __init__(self, trainer, stage): + self.trainer = trainer + self._stage = stage + self.reset() + + def __getitem__(self, key: str) -> Any: + try: + if key in self._internals: + return self._internals[key] + return None + except KeyError: + return None + + @property + def has_split_and_opt_idx(self): + if self._split_idx is not None and self._opt_idx is not None: + return True + return False + + @property + def extra_info(self): + return {"batch_idx": self.trainer.batch_idx, + "split_idx": self._split_idx, + "opt_idx": self._opt_idx} + + def reset_model(self): + model_ref = self.trainer.get_model() + model_ref._results = Result() + model_ref._current_hook_fx_name = '' + model_ref._current_fx_name = '' + + def current_model_info(self): + model_ref = self.trainer.get_model() + # extract hook information + fx_name = model_ref._current_hook_fx_name + if fx_name == '': + fx_name = model_ref._current_fx_name + dataloader_idx = model_ref._current_dataloader_idx + return fx_name, dataloader_idx + + def cache_result(self): + """ + This function is called after every hook + and stores the result object + """ + model_ref = self.trainer.get_model() + + # extract hook results + hook_result = model_ref._results + + # extract model information + fx_name, dataloader_idx = self.current_model_info() + + # add only if anything has been logged + # default len is 1 due to _internals + if len(hook_result) > 1: + + if fx_name not in self._internals: + self._internals[fx_name] = HookResultStore(fx_name) + + extra_info = {} + if self.has_split_and_opt_idx: + extra_info = self.extra_info + + # attach the captured batch_size + Result.attach_batch_size(self._batch_size, hook_result) + + self._internals[fx_name].append( + deepcopy(hook_result), + dataloader_idx=dataloader_idx, + extra_info=extra_info) + + # update logged_metrics, progress_bar_metrics, callback_metrics + self.update_logger_connector(fx_name) + + # reset _results, fx_name + self.reset_model() + + def update_logger_connector(self, fx_name=None): + """ + This function is called every time we capture a hook + It automatically updates the following logger_connector attributes: + - progress_bar_metrics with pbar_metrics + - logged_metrics with log_metrics + - callback_metrics with progress_bar_metrics + logged_metrics + """ + + logger_connector = self.trainer.logger_connector + + callback_metrics = {} + + if not self._has_batch_loop_finished: + # get pbar + batch_pbar_metrics = self.get_latest_batch_pbar_metrics() + logger_connector.add_progress_bar_metrics(batch_pbar_metrics) + + if self._stage in LoggerStages.TRAIN.value: + # Only log and add to callback epoch step during evaluation, test.
+ batch_log_metrics = self.get_latest_batch_log_metrics() + logger_connector.logged_metrics.update(batch_log_metrics) + + callback_metrics.update(batch_pbar_metrics) + callback_metrics.update(batch_log_metrics) + else: + epoch_dict = {"epoch": self.trainer.current_epoch} + + # get pbar + epoch_pbar_metrics = self.get_epoch_pbar_metrics() + logger_connector.add_progress_bar_metrics(epoch_pbar_metrics) + + # get logged_metrics + epoch_log_metrics = self.get_epoch_log_metrics() + logger_connector.logged_metrics.update(epoch_log_metrics) + logger_connector.logged_metrics.update(epoch_dict) + + # get forked_metrics + forked_metrics = self.get_forked_metrics() + + callback_metrics.update(epoch_pbar_metrics) + callback_metrics.update(epoch_log_metrics) + callback_metrics.update(forked_metrics) + + # update callback_metrics + logger_connector.callback_metrics.update(callback_metrics) + logger_connector.callback_metrics.pop("epoch", None) + + def get_latest_batch_log_metrics(self): + results = {} + for fx_name, hook_result in self._internals.items(): + results.update(hook_result.get_batch_log_metrics( + latest=True, + include_forked_originals=False)) + return results + + def get_latest_batch_pbar_metrics(self): + results = {} + for fx_name, hook_result in self._internals.items(): + results.update(hook_result.get_batch_pbar_metrics( + latest=True, + include_forked_originals=False)) + return results + + @property + def has_reduced(self): + hook_results = self._internals.values() + return len(hook_results) == sum([h.has_reduced for h in hook_results]) + + def auto_reduce_results_on_epoch_end(self): + if not self.has_reduced: + for fx_name, hook_result in self._internals.items(): + hook_result.auto_reduce_results_on_epoch_end() + + @property + def has_batch_loop_finished(self): + return self._has_batch_loop_finished + + @has_batch_loop_finished.setter + def has_batch_loop_finished(self, has_batch_loop_finished): + if has_batch_loop_finished: + # If batch loop has finished, reduce metrics + self.auto_reduce_results_on_epoch_end() + + # batch_size should be none as we finished batch loop + self._batch_size = None + + self._has_batch_loop_finished = has_batch_loop_finished + self.update_logger_connector() + + def get_epoch_pbar_metrics(self): + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + epoch_pbar_metrics = {} + for fx_name, hook_result in self._internals.items(): + epoch_pbar_metrics.update(hook_result.get_epoch_pbar_metrics()) + return epoch_pbar_metrics + + def get_epoch_log_metrics(self): + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + epoch_log_metrics = {} + for fx_name, hook_result in self._internals.items(): + epoch_log_metrics.update(hook_result.get_epoch_log_metrics()) + return epoch_log_metrics + + def get_forked_metrics(self): + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + forked_metrics = {} + for fx_name, hook_result in self._internals.items(): + forked_metrics.update(hook_result.get_forked_metrics()) + return forked_metrics + + def get_reduced_metrics(self): + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + reduced_metrics = {} + for fx_name, hook_result in self._internals.items(): + reduced_metrics[fx_name] = hook_result.get_reduced_metrics() + return reduced_metrics + + def __repr__(self): + return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" + + def reset(self): + self._internals = {} + self._dataloader_idx = None + self._split_idx = None + self._opt_idx = None 
+ self._batch_size = None + self._has_batch_loop_finished = False + self._num_dataloaders = 1 diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py similarity index 86% rename from pytorch_lightning/trainer/connectors/logger_connector.py rename to pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 893eab5a16a3d..0918c31ffd974 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from pprint import pprint +from typing import Iterable, Union +from copy import deepcopy +from collections import ChainMap import torch from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection @@ -19,20 +23,87 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pprint import pprint -from typing import Iterable -from copy import deepcopy -from collections import ChainMap +from pytorch_lightning.trainer.connectors.logger_connector import ( + CallbackHookNameValidator, + EpochResultStore, + LoggerStages +) class LoggerConnector: + __lookup_stages = {"1": "test", "0": "validation", "True": "test", "False": "validation"} + def __init__(self, trainer): self.trainer = trainer self.callback_metrics = {} self.logged_metrics = {} self.progress_bar_metrics = {} self.eval_loop_results = [] + self._stages = sorted([s.value for s in LoggerStages]) + self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in self._stages} + self._callback_hook_validator = CallbackHookNameValidator() + self._current_stage = None + + def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]: + """ Function to access cached_results using str or bool. 
Bool is used only for testing""" + stage_or_testing = str(stage_or_testing) + stages = self._stages + if stage_or_testing in self._stages: + return self._cached_results[stage_or_testing] + if stage_or_testing in self.__lookup_stages: + # Access using trainer.testing + stage = self.__lookup_stages[stage_or_testing] + return self._cached_results[stage] + raise MisconfigurationException( + f"Provided stage_or_testing {stage_or_testing} doesn't belong to either {self._stages}" + f" or {self.__lookup_stages.keys()}" + ) + + def set_stage(self, stage_or_testing: str, reset: bool = False) -> None: + self._current_stage = self._determine_stage(stage_or_testing) + if reset: + self.cached_results(stage_or_testing).reset() + + def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: + self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name, + on_step=on_step, + on_epoch=on_epoch) + + def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): + # reset the result of the PL module + model = self.trainer.get_model() + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + + # track batch_size + self.cached_results(testing)._batch_size = Result.extract_batch_size(batch) + + def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None: + self._cached_results["train"]._split_idx = split_idx + self._cached_results["train"]._opt_idx = opt_idx + self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch) + + def on_train_batch_end(self) -> None: + self._cached_results["train"]._split_idx = None + self._cached_results["train"]._opt_idx = None + self._cached_results["train"]._batch_size = None + + def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: + stage_or_testing = str(stage_or_testing) + stages = self._stages + if stage_or_testing in stages: + return stage_or_testing + if stage_or_testing in self.__lookup_stages: + # Access using trainer.testing + return self.__lookup_stages[stage_or_testing] + raise MisconfigurationException( + f"Provided stage_or_testing {stage_or_testing} doesn't belong to either {stages}" + f" or {self.__lookup_stages.keys()}" + ) + + def cache_logged_metrics(self) -> Union[EpochResultStore, None]: + if self._current_stage is not None: + self._cached_results[self._current_stage].cache_result() def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps): # logging diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9dab036583dd8..eb9020f91903d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -29,6 +29,7 @@ def __init__(self, trainer): self.predictions = None self.max_batches = None self.warning_cache = WarningCache() + self.num_dataloaders = None def on_trainer_init(self): self.trainer.num_val_batches = [] @@ -108,6 +109,9 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + # reset stage to train + self.trainer.logger_connector.set_stage("train") + def reload_evaluation_dataloaders(self): model = self.trainer.get_model() if self.testing: @@ -133,6 +137,7 @@ def setup(self, model, max_batches, dataloaders): max_batches = [max_batches] * len(dataloaders) self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) def on_evaluation_epoch_start(self, *args, **kwargs):
if self.testing: @@ -292,16 +297,20 @@ def __auto_reduce_result_objs(self, outputs): return eval_results - def on_evaluation_batch_start(self, *args, **kwargs): + def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # reset the result of the PL module model = self.trainer.get_model() model._results = Result() model._current_fx_name = 'evaluation_step' + # set dataloader_idx and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start( + self.testing, batch, dataloader_idx, self.num_dataloaders) + if self.testing: - self.trainer.call_hook('on_test_batch_start', *args, **kwargs) + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) else: - self.trainer.call_hook('on_validation_batch_start', *args, **kwargs) + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) def on_evaluation_batch_end(self, *args, **kwargs): if self.testing: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index b585647fb5a0e..ce4a1d7b1fcd0 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -92,7 +92,7 @@ def process_dict_result(self, output, train=False): # --------------- # all keys not progress_bar or log are candidates for callbacks callback_metrics = {} - if output: + if isinstance(output, dict): for k, v in output.items(): if k not in ['progress_bar', 'log', 'hiddens']: callback_metrics[k] = v @@ -156,7 +156,7 @@ def process_dict_result(self, output, train=False): # --------------- # EXTRACT HIDDEN # --------------- - hiddens = output.get('hiddens') if output else None + hiddens = output.get('hiddens', None) if isinstance(output, dict) else None # use every metric passed in as a candidate for callback callback_metrics.update(progress_bar_metrics) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 008633273a0d1..fbc2748456b26 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -22,7 +22,8 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import EvalResult +from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -465,6 +466,9 @@ def fit( def train(self): self.run_sanity_check(self.get_model()) + # set stage for logging + self.logger_connector.set_stage("train") + self.checkpoint_connector.has_trained = False # enable train mode @@ -528,16 +532,25 @@ def train(self): self.train_loop.on_train_end() def run_evaluation(self, test_mode: bool = False, max_batches=None): + + # used to know if we are logging for val, test + reset cached results + self.logger_connector.set_stage(test_mode, reset=True) + # bookkeeping self.evaluation_loop.testing = test_mode + + # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) + + # check if we want to skip this evaluation if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): return [], [] - # enable eval mode + no grads + # ref model model = self.get_model() - self.evaluation_loop.on_evaluation_model_eval() + # 
enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() model.zero_grad() torch.set_grad_enabled(False) @@ -701,6 +714,8 @@ def test( # -------------------- self.verbose_test = verbose + self.logger_connector.set_stage("test") + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d1dfb3eec3733..f382f284e871b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -661,20 +661,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): - self.trainer.split_idx = split_idx - # in manual optimization we loop over all optimizers at once - optimizers = self.get_optimizers_iterable() - if not self.automatic_optimization: - optimizers = [optimizers[0]] + # create an iterable for optimizers and loop over them + for opt_idx, optimizer in self.prepare_optimizers(): - # loop over optimizers - for opt_idx, optimizer in optimizers: - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. - if self.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.get_model() - model.toggle_optimizer(optimizer, opt_idx) + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) if should_accumulate: # For gradient accumulation @@ -729,6 +721,7 @@ def train_step_and_backward_closure(): opt_idx=opt_idx, ) + # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None @@ -738,14 +731,8 @@ def train_step_and_backward_closure(): # clear gradients self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - accumulated_loss = self.accumulated_loss.mean() - - if accumulated_loss is not None: - # calculate running loss for display - self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - - # reset for next set of accumulated grads - self.accumulated_loss.reset() + # update running loss + reset accumulated loss + self.update_running_loss() # collapse all metrics into one dict batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} @@ -934,3 +921,33 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs + + def prepare_optimizers(self): + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + return optimizers + + def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # set split_idx to trainer for tracking + self.trainer.split_idx = split_idx + + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. 
+ if self.automatic_optimization and len(self.trainer.optimizers) > 1: + model = self.trainer.get_model() + model.toggle_optimizer(optimizer, opt_idx) + + # use to track metrics internally + self.trainer.logger_connector.on_batch_start(split_idx, opt_idx, split_batch) + + def update_running_loss(self): + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py new file mode 100644 index 0000000000000..4097af4668fc5 --- /dev/null +++ b/tests/trainer/logging/test_logger_connector.py @@ -0,0 +1,249 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import os +import torch +import pytest + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from tests.base.boring_model import BoringModel, RandomDataset + + +class Helper: + def decorator_with_arguments(fx_name='', hook_fx_name=''): + def decorator(func): + def wrapper(self, *args, **kwargs): + # Set information + self._current_fx_name = fx_name + self._current_hook_fx_name = hook_fx_name + self._results = Result() + + result = func(self, *args, **kwargs) + + # cache metrics + self.trainer.logger_connector.cache_logged_metrics() + return result + return wrapper + + return decorator + + +def test__logger_connector__epoch_result_store__train(tmpdir): + """ + Tests that LoggerConnector will properly capture logged information + and reduce them + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestModel(BoringModel): + + train_losses = [] + + @Helper.decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + + self.train_losses.append(loss) + + self.log("train_loss", loss, on_step=True, on_epoch=True) + return {"loss": loss} + + def val_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64))] + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=4, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']) == 2 + assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0'][0]["train_loss"] == model.train_losses[0] + assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['1'][0]["train_loss"] == model.train_losses[1] + + # 
assert reduction didn't happen yet + assert trainer.logger_connector.cached_results("train").has_reduced is False + + # Launch reduction + trainer.logger_connector.cached_results("train").has_batch_loop_finished is True + + # assert reduction did happen + assert trainer.logger_connector.cached_results("train").has_reduced is True + + assert trainer.logger_connector.cached_results("train")["training_step"]\ + ._internals_reduced["0"]["0"]['train_loss_epoch'] == torch.stack(model.train_losses).mean() + + +def test__logger_connector__epoch_result_store__train__ttbt(tmpdir): + """ + Tests that LoggerConnector will properly capture logged information with ttbt + and reduce them + """ + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class TestModel(BoringModel): + + train_losses = [] + + def __init__(self): + super().__init__() + self.test_hidden = None + self.layer = torch.nn.Linear(2, 2) + + @Helper.decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx, hiddens): + try: + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + except Exception as e: + print(e) + + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss = torch.nn.functional.mse_loss( + pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + self.train_losses.append(loss) + + self.log('a', loss, on_epoch=True) + + return {'loss': loss, 'hiddens': self.test_hidden} + + def on_train_epoch_start(self) -> None: + self.test_hidden = None + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + model = TestModel() + model.training_epoch_end = None + model.example_input_array = torch.randn(5, truncated_bptt_steps) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=0, + truncated_bptt_steps=truncated_bptt_steps, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0']) == len(model.train_losses) + + # assert reduction didn't happen yet + assert trainer.logger_connector.cached_results("train").has_reduced is False + + # Launch reduction + trainer.logger_connector.cached_results("train").has_batch_loop_finished is True + + # assert reduction did happen + assert trainer.logger_connector.cached_results("train").has_reduced is True + + assert trainer.logger_connector.cached_results("train")['training_step']\ + ._internals_reduced['0']['0']["a_epoch"] == torch.stack(model.train_losses).mean() + + +@pytest.mark.parametrize('num_dataloaders', [1, 2]) +def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders): + """ + Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestModel(BoringModel): 
+ + test_losses = {} + + @Helper.decorator_with_arguments(fx_name="test_step") + def test_step(self, batch, batch_idx, dataloader_idx=0): + output = self.layer(batch) + loss = self.loss(batch, output) + + primary_key = str(dataloader_idx) + if primary_key not in self.test_losses: + self.test_losses[primary_key] = [] + + self.test_losses[primary_key].append(loss) + + self.log("test_loss", loss, on_step=True, on_epoch=True) + return {"test_loss": loss} + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] + + model = TestModel() + model.val_dataloader = None + model.test_epoch_end = None + + limit_test_batches = 4 + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=limit_test_batches, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.test(model) + + assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals) == num_dataloaders + for dl_idx in range(num_dataloaders): + assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals[str(dl_idx)]) == limit_test_batches + trainer.logger_connector.cached_results("test").has_batch_loop_finished = True + for dl_idx in range(num_dataloaders): + expected = torch.stack(model.test_losses[str(dl_idx)]).mean() + generated = trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] + assert expected == generated From 8417a15ea558cb1d8ff4f5d79b4d8b7a43373c71 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 10:08:41 +0000 Subject: [PATCH 02/15] typo --- tests/trainer/logging/test_logger_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 4097af4668fc5..b40952a839130 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -90,7 +90,7 @@ def val_dataloader(self): assert trainer.logger_connector.cached_results("train").has_reduced is False # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished is True + trainer.logger_connector.cached_results("train").has_batch_loop_finished = True # assert reduction did happen assert trainer.logger_connector.cached_results("train").has_reduced is True @@ -184,7 +184,7 @@ def train_dataloader(self): assert trainer.logger_connector.cached_results("train").has_reduced is False # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished is True + trainer.logger_connector.cached_results("train").has_batch_loop_finished = True # assert reduction did happen assert trainer.logger_connector.cached_results("train").has_reduced is True From e21b249c7d8bf2b25642402bd0c29474cf8f9557 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 10:40:26 +0000 Subject: [PATCH 03/15] typo --- .../connectors/logger_connector/logger_connector.py | 9 +++------ tests/trainer/logging/test_logger_connector.py | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0918c31ffd974..92326e77708ac 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ 
b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -522,8 +522,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output): for opt_outputs in epoch_output: # reduce across time first time_reduced_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) if len(tbptt_outs) > 1: time_reduced_outputs.append(tbptt_outs) @@ -553,8 +552,7 @@ def __prepare_epoch_end_inputs(self, epoch_output): for opt_outputs in epoch_output: # gather across time first time_gathered_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: result = [] for x in tbptt_outs: out = x.extra @@ -582,8 +580,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): for opt_outputs in epoch_output: # gather across time first time_gathered_outputs = [] - for train_step_idx in range(len(opt_outputs)): - tbptt_outs = opt_outputs[train_step_idx] + for tbptt_outs in opt_outputs: tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs) time_gathered_outputs.append(tbptt_outs) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index b40952a839130..d891619f6140a 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -96,7 +96,7 @@ def val_dataloader(self): assert trainer.logger_connector.cached_results("train").has_reduced is True assert trainer.logger_connector.cached_results("train")["training_step"]\ - ._internals_reduced["0"]["0"]['train_loss_epoch'] == torch.stack(model.train_losses).mean() + ._internals_reduced["0"]["0"]['train_loss_epoch'].item() == torch.stack(model.train_losses).mean().item() def test__logger_connector__epoch_result_store__train__ttbt(tmpdir): @@ -190,7 +190,7 @@ def train_dataloader(self): assert trainer.logger_connector.cached_results("train").has_reduced is True assert trainer.logger_connector.cached_results("train")['training_step']\ - ._internals_reduced['0']['0']["a_epoch"] == torch.stack(model.train_losses).mean() + ._internals_reduced['0']['0']["a_epoch"].item() == torch.stack(model.train_losses).mean().item() @pytest.mark.parametrize('num_dataloaders', [1, 2]) @@ -246,4 +246,4 @@ def test_dataloader(self): for dl_idx in range(num_dataloaders): expected = torch.stack(model.test_losses[str(dl_idx)]).mean() generated = trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] - assert expected == generated + assert expected.item() == generated.item() From f9022d78de45b1c9b4ce6b7709b2748f5784c55c Mon Sep 17 00:00:00 2001 From: chaton Date: Fri, 30 Oct 2020 11:20:17 +0000 Subject: [PATCH 04/15] Update pytorch_lightning/trainer/logging.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/trainer/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index ce4a1d7b1fcd0..de5484973e390 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -92,7 +92,7 @@ def process_dict_result(self, output, train=False): # --------------- # all keys not progress_bar or log are candidates for callbacks callback_metrics = {} - if isinstance(output, dict): + if isinstance(output, Mapping): for k, 
v in output.items(): if k not in ['progress_bar', 'log', 'hiddens']: callback_metrics[k] = v From f7fd584eebebe970bc724b8725a38a865b34aec5 Mon Sep 17 00:00:00 2001 From: chaton Date: Fri, 30 Oct 2020 11:20:28 +0000 Subject: [PATCH 05/15] Update pytorch_lightning/trainer/logging.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/trainer/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index de5484973e390..d282465d6baab 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -156,7 +156,7 @@ def process_dict_result(self, output, train=False): # --------------- # EXTRACT HIDDEN # --------------- - hiddens = output.get('hiddens', None) if isinstance(output, dict) else None + hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None # use every metric passed in as a candidate for callback callback_metrics.update(progress_bar_metrics) From de910e15f431395af66d31d37c8863731183ba5b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 11:22:01 +0000 Subject: [PATCH 06/15] update on comments --- pytorch_lightning/core/lightning.py | 4 ++-- .../trainer/connectors/logger_connector/__init__.py | 2 -- .../trainer/connectors/logger_connector/epoch_loop_result.py | 2 +- .../trainer/connectors/logger_connector/logger_connector.py | 4 ++-- pytorch_lightning/trainer/logging.py | 2 +- tests/trainer/logging/test_logger_connector.py | 2 +- 6 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d0c7cf1fdb465..623e532777044 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,7 +111,7 @@ def __init__(self, *args, **kwargs): self._datamodule = None self._results: Optional[Result] = None self._current_fx_name = '' - self._current_hook_fx_name = '' + self._current_hook_fx_name = None self._current_dataloader_idx = None def optimizers(self): @@ -246,7 +246,7 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) - if self._current_hook_fx_name != '': + if self._current_hook_fx_name is not None: self.trainer.logger_connector\ .check_logging_in_callbacks(self._current_hook_fx_name, on_step=on_step, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py index dd3b1c2e247e7..4034840a09b97 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -1,3 +1 @@ -from pytorch_lightning.trainer.connectors.logger_connector.epoch_loop_result import EpochResultStore, LoggerStages -from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py index 4b3af4f4af8da..cdb3526fe19ac 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py @@ -285,7 +285,7 @@ def extra_info(self): def reset_model(self): model_ref = 
self.trainer.get_model() model_ref._results = Result() - model_ref._current_hook_fx_name = '' + model_ref._current_hook_fx_name = None model_ref._current_fx_name = '' def current_model_info(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 92326e77708ac..8625a7112168f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -23,8 +23,8 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.trainer.connectors.logger_connector import ( - CallbackHookNameValidator, +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.epoch_loop_result import ( EpochResultStore, LoggerStages ) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index d282465d6baab..ae4d280d54649 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -14,7 +14,7 @@ from abc import ABC import inspect -from typing import Union, Iterable +from typing import Union, Iterable, Mapping import torch diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index d891619f6140a..0f27f2ca4fef4 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -246,4 +246,4 @@ def test_dataloader(self): for dl_idx in range(num_dataloaders): expected = torch.stack(model.test_losses[str(dl_idx)]).mean() generated = trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] - assert expected.item() == generated.item() + assert abs(expected.item() - generated.item()) < 1e-6 From eaf349bae9d73f01d3ab31727e339b38476d9128 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 11:42:42 +0000 Subject: [PATCH 07/15] update on comments --- .../logger_connector/epoch_loop_result.py | 102 ++++++++++-------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py index cdb3526fe19ac..1196a35994a56 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py @@ -28,22 +28,32 @@ class LoggerStages(Enum): class HookResultStore: """ - This class is used to hold all metrics logged during one callback or model hook. - Can be used for both training, val, test. - - Result objects will be stored in the following way. - - val and test: self._internals = {"dataloader_idx": [Result(), ..., Result()]} - training - - IF optimizer_idx and training_step_idx are set, - THEN self._internals = {"dataloader_idx": + This class is defined for internal usage. + It holds all metrics logged using the self.log function + in the scope of ModelHooks or Callback functions. + + We need to differiante 3 different scenarios: + - (1): We are outside of a batch loop + * It means no dataloader_idx, no optimizer idx, etc.. 
+ - (2): We are inside the training batch loop + * We have an optimizer idx and split idx to track + - (3): We are inside the evaluation loop + * We have a dataloader_idx to track + + The data store `Result` objects for those 3 scenarios in `self._internals`. + + (1): self._internals = {"dataloader_idx": [Result(), ..., Result()]} + * dataloader_idx not being defined, it is set to 0 b default + (2): self._internals = {"dataloader_idx": {"optimizer_idx": - {"training_step_idx": + {"batch_idx": [Result(), Result()] } } } - - ELSE self._internals = {"dataloader_idx": [Result(), ..., Result()]} + (3): Same as (1) for simplicity + + Those data structures enables us to reduce properly Result object when batch loop is finished. """ _types = ["list", "dict"] @@ -200,7 +210,7 @@ def auto_reduce_results_on_epoch_end(self): opt_idx = str(opt_idx) # TODO: Figure out to reduce memory # TODO: How to start training in middle of epoch - opt_outputs = deepcopy(epoch_metrics[opt_idx]) + opt_outputs = epoch_metrics[opt_idx] num_batch_idx = len(self._internals[dl_idx][str(num_opt_idx)]) - 1 assert num_batch_idx >= 0 @@ -226,15 +236,21 @@ def auto_reduce_results_on_epoch_end(self): opt_outputs.minimize = opt_outputs.minimize.mean() self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs + + # free memory + del self._internals[dl_idx] else: # no need to reduce as called only once if len(epoch_metrics) == 1: - reduced_epoch_metrics = deepcopy(epoch_metrics[0]) + reduced_epoch_metrics = epoch_metrics[0] else: - reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(deepcopy(epoch_metrics)) + reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics) self._internals_reduced[dl_idx] = reduced_epoch_metrics + # free memory + del self._internals[dl_idx] + self.has_reduced = True def __getitem__(self, key: str) -> Any: @@ -251,11 +267,22 @@ def __repr__(self): class EpochResultStore: """ - This class is responsible to cache all logging metrics which happened during one epoch + This class is defined for internal usage. + + It holds all metrics logged using the self.log function using `HookResultStore` object. + + The internal datastructure is as follow: - It will cache Result objects as follow. + self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} + + Pseudo Code Example: + ``` + model._current_fx_name = 'something' + model._results = Result() + model.log('a', ...) 
+ epoch_result_store.cache_result() + ``` - self._internals = {"fx_name_0": HookResult(), ..., "fx_name_n": HookResult()} """ def __init__(self, trainer, stage): self.trainer = trainer @@ -283,6 +310,7 @@ def extra_info(self): "opt_idx": self._opt_idx} def reset_model(self): + # reset model to its capture state model_ref = self.trainer.get_model() model_ref._results = Result() model_ref._current_hook_fx_name = None @@ -425,40 +453,26 @@ def has_batch_loop_finished(self, has_batch_loop_finished): self._has_batch_loop_finished = has_batch_loop_finished self.update_logger_connector() - def get_epoch_pbar_metrics(self): + def run_by_func_name(self, func_name): if not self.has_reduced: self.auto_reduce_results_on_epoch_end() - epoch_pbar_metrics = {} + results = {} for fx_name, hook_result in self._internals.items(): - epoch_pbar_metrics.update(hook_result.get_epoch_pbar_metrics()) - return epoch_pbar_metrics + func = getattr(hook_result, func_name) + results.update(func()) + return results + + def get_epoch_pbar_metrics(self): + return self.run_by_func_name("get_epoch_pbar_metrics") def get_epoch_log_metrics(self): - if not self.has_reduced: - self.auto_reduce_results_on_epoch_end() - epoch_log_metrics = {} - for fx_name, hook_result in self._internals.items(): - epoch_log_metrics.update(hook_result.get_epoch_log_metrics()) - return epoch_log_metrics + return self.run_by_func_name("get_epoch_log_metrics") def get_forked_metrics(self): - if not self.has_reduced: - self.auto_reduce_results_on_epoch_end() - forked_metrics = {} - for fx_name, hook_result in self._internals.items(): - forked_metrics.update(hook_result.get_forked_metrics()) - return forked_metrics + return self.run_by_func_name("get_forked_metrics") def get_reduced_metrics(self): - if not self.has_reduced: - self.auto_reduce_results_on_epoch_end() - reduced_metrics = {} - for fx_name, hook_result in self._internals.items(): - reduced_metrics[fx_name] = hook_result.get_reduced_metrics() - return reduced_metrics - - def __repr__(self): - return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" + return self.run_by_func_name("get_reduced_metrics") def reset(self): self._internals = {} @@ -467,4 +481,6 @@ def reset(self): self._opt_idx = None self._batch_size = None self._has_batch_loop_finished = False - self._num_dataloaders = 1 + + def __repr__(self): + return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" From b81aecfa2871695324bf2b1cc68f5f36636ae1ce Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 11:57:24 +0000 Subject: [PATCH 08/15] add more doctstring --- ...h_loop_result.py => epoch_result_store.py} | 84 +++++++++++-------- .../logger_connector/logger_connector.py | 2 +- 2 files changed, 50 insertions(+), 36 deletions(-) rename pytorch_lightning/trainer/connectors/logger_connector/{epoch_loop_result.py => epoch_result_store.py} (87%) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py similarity index 87% rename from pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py rename to pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 1196a35994a56..8745bd27d59e1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_loop_result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -15,7 +15,7 @@ from collections import defaultdict from copy 
import deepcopy from enum import Enum -from typing import Union, Tuple, Any +from typing import Union, Tuple, Any, Mapping from pytorch_lightning.core.step_result import Result @@ -118,7 +118,7 @@ def get_batch_pbar_metrics(self, latest=True, *args, **kwargs): def get_batch_log_metrics(self, latest=True, *args, **kwargs): return self.get_lastest_from_func_name("get_batch_log_metrics", *args, latest=latest, **kwargs) - def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs): + def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if isinstance(opt_metric, Result): func = getattr(opt_metric, func_name) metrics_to_log = func( @@ -129,7 +129,7 @@ def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs): else: raise Exception("The provided opt_metric should be a Result Object. Something is wrong") - def get_epoch_from_func_name(self, func_name, *args, **kwargs): + def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: results = {} for dl_idx in range(self.num_dataloaders): dl_idx = str(dl_idx) @@ -141,17 +141,17 @@ def get_epoch_from_func_name(self, func_name, *args, **kwargs): self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) return results - def get_epoch_pbar_metrics(self, *args, **kwargs): + def get_epoch_pbar_metrics(self, *args, **kwargs) -> Mapping: return self.get_epoch_from_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self, *args, **kwargs): + def get_epoch_log_metrics(self, *args, **kwargs) -> Mapping: return self.get_epoch_from_func_name("get_epoch_log_metrics") - def get_forked_metrics(self, *args, **kwargs): + def get_forked_metrics(self, *args, **kwargs) -> Mapping: return self.get_epoch_from_func_name("get_forked_metrics") @staticmethod - def _append_to_structure(primary_dict, opt_idx, batch_idx, result): + def _append_to_structure(primary_dict, opt_idx, batch_idx, result) -> None: if opt_idx not in primary_dict: primary_dict[opt_idx] = {} @@ -160,7 +160,7 @@ def _append_to_structure(primary_dict, opt_idx, batch_idx, result): primary_dict[opt_idx][batch_idx].append(result) - def append(self, result, dataloader_idx=None, extra_info: dict = {}): + def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: assert isinstance(result, Result) @@ -190,7 +190,12 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}): self._internals[primary_key] = [] self._internals[primary_key].append(result) - def auto_reduce_results_on_epoch_end(self): + def auto_reduce_results_on_epoch_end(self) -> None: + """ + This function is called to reduce `self._internals` Result object. + The reduced Result object will be saved into `self._internals_reduced` + The `self._internals` stored Result objects will be deleted to save memory. 
+ """ if not self.has_reduced: epoch_log_metrics = {} epoch_progress_bar_metrics = {} @@ -299,24 +304,36 @@ def __getitem__(self, key: str) -> Any: @property def has_split_and_opt_idx(self): + """ + This function informs if we are running within training batch loop + """ if self._split_idx is not None and self._opt_idx is not None: return True return False @property def extra_info(self): + """ + This function provides necessary parameters to properly configure HookResultStore obj + """ return {"batch_idx": self.trainer.batch_idx, "split_idx": self._split_idx, "opt_idx": self._opt_idx} def reset_model(self): - # reset model to its capture state + """ + This function is used to reset model state at the end of the capture + """ model_ref = self.trainer.get_model() model_ref._results = Result() model_ref._current_hook_fx_name = None model_ref._current_fx_name = '' def current_model_info(self): + """ + This function is used to extract + information related to current function scoping `self.log` call. + """ model_ref = self.trainer.get_model() # extract hook information fx_name = model_ref._current_hook_fx_name @@ -325,7 +342,7 @@ def current_model_info(self): dataloader_idx = model_ref._current_dataloader_idx return fx_name, dataloader_idx - def cache_result(self): + def cache_result(self) -> None: """ This function is called after every hook and store the result object @@ -363,7 +380,7 @@ def cache_result(self): # reset _results, fx_name self.reset_model() - def update_logger_connector(self, fx_name=None): + def update_logger_connector(self, fx_name: str = None) -> None: """ This function is called every time we capture a hook It automatically updates the logger_connector followings: @@ -411,34 +428,31 @@ def update_logger_connector(self, fx_name=None): logger_connector.callback_metrics.update(callback_metrics) logger_connector.callback_metrics.pop("epoch", None) - def get_latest_batch_log_metrics(self): + def run_batch_from_func_name(self, func_name) -> Mapping: results = {} for fx_name, hook_result in self._internals.items(): - results.update(hook_result.get_batch_log_metrics( - latest=True, - include_forked_originals=False)) + func = getattr(hook_result, func_name) + results.update(func(latest=True, include_forked_originals=False)) return results - def get_latest_batch_pbar_metrics(self): - results = {} - for fx_name, hook_result in self._internals.items(): - results.update(hook_result.get_batch_pbar_metrics( - latest=True, - include_forked_originals=False)) - return results + def get_latest_batch_log_metrics(self) -> Mapping: + return self.run_batch_from_func_name("get_batch_log_metrics") + + def get_latest_batch_pbar_metrics(self) -> Mapping: + return self.run_batch_from_func_name("get_batch_pbar_metrics") @property - def has_reduced(self): + def has_reduced(self) -> bool: hook_results = self._internals.values() return len(hook_results) == sum([h.has_reduced for h in hook_results]) - def auto_reduce_results_on_epoch_end(self): + def auto_reduce_results_on_epoch_end(self) -> None: if not self.has_reduced: for fx_name, hook_result in self._internals.items(): hook_result.auto_reduce_results_on_epoch_end() @property - def has_batch_loop_finished(self): + def has_batch_loop_finished(self) -> bool: return self._has_batch_loop_finished @has_batch_loop_finished.setter @@ -453,7 +467,7 @@ def has_batch_loop_finished(self, has_batch_loop_finished): self._has_batch_loop_finished = has_batch_loop_finished self.update_logger_connector() - def run_by_func_name(self, func_name): + def 
run_epoch_by_func_name(self, func_name) -> Mapping: if not self.has_reduced: self.auto_reduce_results_on_epoch_end() results = {} @@ -462,17 +476,17 @@ def run_by_func_name(self, func_name): results.update(func()) return results - def get_epoch_pbar_metrics(self): - return self.run_by_func_name("get_epoch_pbar_metrics") + def get_epoch_pbar_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self): - return self.run_by_func_name("get_epoch_log_metrics") + def get_epoch_log_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_epoch_log_metrics") - def get_forked_metrics(self): - return self.run_by_func_name("get_forked_metrics") + def get_forked_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_forked_metrics") - def get_reduced_metrics(self): - return self.run_by_func_name("get_reduced_metrics") + def get_reduced_metrics(self) -> Mapping: + return self.run_epoch_by_func_name("get_reduced_metrics") def reset(self): self._internals = {} diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8625a7112168f..96b0e4e77f9b1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator -from pytorch_lightning.trainer.connectors.logger_connector.epoch_loop_result import ( +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import ( EpochResultStore, LoggerStages ) From 694ded202087415ee6a37409b1ba0133e967e8a0 Mon Sep 17 00:00:00 2001 From: chaton Date: Fri, 30 Oct 2020 17:44:44 +0000 Subject: [PATCH 09/15] Update pytorch_lightning/core/lightning.py Co-authored-by: Sean Naren --- pytorch_lightning/core/lightning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 623e532777044..0fded64c6c8e9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -247,10 +247,11 @@ def log( on_epoch = self.__auto_choose_log_on_epoch(on_epoch) if self._current_hook_fx_name is not None: - self.trainer.logger_connector\ - .check_logging_in_callbacks(self._current_hook_fx_name, - on_step=on_step, - on_epoch=on_epoch) + self.trainer.logger_connector.check_logging_in_callbacks( + self._current_hook_fx_name, + on_step=on_step, + on_epoch=on_epoch + ) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: From cab95b831c98bec8e0077a800dbfa0b4e2116df5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Oct 2020 18:35:59 +0000 Subject: [PATCH 10/15] resolve on comments --- .../logger_connector/epoch_result_store.py | 20 ++++++++++++------- .../logger_connector/logger_connector.py | 5 +++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 8745bd27d59e1..f5806e10a725c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ 
b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -19,6 +19,9 @@ from pytorch_lightning.core.step_result import Result +# used to map boolean to right LoggerStage values +LOOKUP_TABLE = {"1": "test", "0": "validation", "True": "test", "False": "validation"} + class LoggerStages(Enum): TRAIN = "train" @@ -26,6 +29,12 @@ class LoggerStages(Enum): TEST = "test" +class StorageType: + + INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" + ELSE = "else" + + class HookResultStore: """ This class is defined for internal usage. @@ -55,9 +64,6 @@ class HookResultStore: Those data structures enables us to reduce properly Result object when batch loop is finished. """ - - _types = ["list", "dict"] - def __init__(self, fx_name): self._fx_name = fx_name self._internals = {} @@ -102,7 +108,7 @@ def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): if latest: for dl_idx in range(self.num_dataloaders): dl_idx = str(dl_idx) - if self._internal_type == self._types[0]: + if self._internal_type == StorageType.ELSE: latest_result = self._internals[dl_idx][-1] else: latest_result = self.get_latest_from_dict(dl_idx) @@ -171,7 +177,7 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: # [dataloader_idx][optimizer_idx][training_step_idx] is a list if len(extra_info) > 0: - self._internal_type = self._types[-1] + self._internal_type = StorageType.INSIDE_BATCH_TRAIN_LOOP # initialize dictionary if primary_key not in self._internals: self._internals[primary_key] = {} @@ -185,7 +191,7 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: # [dataloader_idx] is a list else: - self._internal_type = self._types[0] + self._internal_type = StorageType.ELSE if primary_key not in self._internals: self._internals[primary_key] = [] self._internals[primary_key].append(result) @@ -204,7 +210,7 @@ def auto_reduce_results_on_epoch_end(self) -> None: dl_idx = str(dl_idx) epoch_metrics = self._internals[dl_idx] - if self._internal_type == self._types[-1]: + if self._internal_type == StorageType.INSIDE_BATCH_TRAIN_LOOP: num_opt_idx = len(self._internals[dl_idx]) - 1 diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 96b0e4e77f9b1..f67b508f67590 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -26,13 +26,14 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import ( EpochResultStore, - LoggerStages + LoggerStages, + LOOKUP_TABLE ) class LoggerConnector: - __lookup_stages = {"1": "test", "0": "validation", "True": "test", "False": "validation"} + __lookup_stages = LOOKUP_TABLE def __init__(self, trainer): self.trainer = trainer From 0fbe5ea7bff39fc350813812de25f91d27c3d983 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 2 Nov 2020 08:23:44 +0000 Subject: [PATCH 11/15] solve pyright --- .../connectors/logger_connector/epoch_result_store.py | 8 ++++---- .../connectors/logger_connector/logger_connector.py | 10 +++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 
f5806e10a725c..f92c411a482a1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -496,10 +496,10 @@ def get_reduced_metrics(self) -> Mapping: def reset(self): self._internals = {} - self._dataloader_idx = None - self._split_idx = None - self._opt_idx = None - self._batch_size = None + self._dataloader_idx: Union[int, None] = None + self._split_idx: Union[int, None] = None + self._opt_idx: Union[int, None] = None + self._batch_size: Union[int, None] = None self._has_batch_loop_finished = False def __repr__(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index f67b508f67590..b19ae0a5e16e9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pprint import pprint -from typing import Iterable, Union +from typing import Iterable, Union, cast from copy import deepcopy from collections import ChainMap import torch @@ -251,12 +251,15 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs): continue reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics) - # make the keys 'k/dl' - reduced_epoch_metrics = self.__rename_keys_by_dataloader_idx(reduced_epoch_metrics, dl_idx, num_loaders) # track the metrics logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics() pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics() + + # make the keys 'k/dl' + logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders) + pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders) + self.logged_metrics.update(logger_metrics) self.add_progress_bar_metrics(pbar_metrics) @@ -301,6 +304,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result): else: self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics) else: + flat = {} if isinstance(eval_results, list): for eval_result in eval_results: # with a scalar return, auto set it to "val_loss" for callbacks From 49a491a5856824b9468c08aedf52ca747f7120a9 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 2 Nov 2020 10:06:18 +0000 Subject: [PATCH 12/15] Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- .../trainer/connectors/logger_connector/epoch_result_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index f92c411a482a1..d42be2ce3cb21 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -29,7 +29,7 @@ class LoggerStages(Enum): TEST = "test" -class StorageType: +class StorageType(Enum): INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" ELSE = "else" From ac09913184a828eea30d1ef13b34aff4b8317a16 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 2 Nov 2020 10:26:51 +0000 Subject: [PATCH 13/15] update on comments --- .../logger_connector/epoch_result_store.py | 25 ++++++++++++++++++- 
.../logger_connector/logger_connector.py | 14 +++++------ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index f92c411a482a1..1d83b36331593 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -19,8 +19,31 @@ from pytorch_lightning.core.step_result import Result + # used to map boolean to right LoggerStage values -LOOKUP_TABLE = {"1": "test", "0": "validation", "True": "test", "False": "validation"} +class FrozenDict(dict): + def __init__(self, *args, **kwargs): + self._hash = None + super(FrozenDict, self).__init__(*args, **kwargs) + + def __hash__(self): + if self._hash is None: + self._hash = hash(tuple(sorted(self.items()))) # iteritems() on py2 + return self._hash + + def _immutable(self, *args, **kws): + raise TypeError('cannot change object - object is immutable') + + __setitem__ = _immutable + __delitem__ = _immutable + pop = _immutable + popitem = _immutable + clear = _immutable + update = _immutable + setdefault = _immutable + + +LOOKUP_TABLE = FrozenDict({"1": "test", "0": "validation", "True": "test", "False": "validation"}) class LoggerStages(Enum): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index b19ae0a5e16e9..5c699ecffa464 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -33,8 +33,6 @@ class LoggerConnector: - __lookup_stages = LOOKUP_TABLE - def __init__(self, trainer): self.trainer = trainer self.callback_metrics = {} @@ -52,13 +50,13 @@ def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResul stages = self._stages if stage_or_testing in self._stages: return self._cached_results[stage_or_testing] - if stage_or_testing in self.__lookup_stages: + if stage_or_testing in LOOKUP_TABLE: # Acces using trainer.testing - stage = self.__lookup_stages[stage_or_testing] + stage = LOOKUP_TABLE[stage_or_testing] return self._cached_results[stage] raise MisconfigurationException( f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}" - f" or {self.__lookup_stages.keys()}" + f" or {LOOKUP_TABLE.keys()}" ) def set_stage(self, stage_or_testing: str, reset:bool = False) -> None: @@ -94,12 +92,12 @@ def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: stages = self._stages if stage_or_testing in stages: return stage_or_testing - if stage_or_testing in self.__lookup_stages: + if stage_or_testing in LOOKUP_TABLE: # Acces using trainer.testing - return self.__lookup_stages[stage_or_testing] + return LOOKUP_TABLE[stage_or_testing] raise MisconfigurationException( f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {stages}" - f" or {self.__lookup_stages.keys()}" + f" or {LOOKUP_TABLE.keys()}" ) def cache_logged_metrics(self) -> Union[EpochResultStore, None]: From e6043130c0fd5dec5ba77d57130dcfe596b892c3 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 2 Nov 2020 13:36:15 +0000 Subject: [PATCH 14/15] Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Sean Naren --- .../trainer/connectors/logger_connector/epoch_result_store.py | 1 - 1 
file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index f03c82e47fef0..d2ba030a8eec3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -53,7 +53,6 @@ class LoggerStages(Enum): class StorageType(Enum): - INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" ELSE = "else" From 14f92b03a21c830833903c80c983266818633a38 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 2 Nov 2020 13:38:25 +0000 Subject: [PATCH 15/15] update on comments --- .../logger_connector/epoch_result_store.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index f03c82e47fef0..2a94990517f0c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -52,10 +52,10 @@ class LoggerStages(Enum): TEST = "test" -class StorageType(Enum): +class ResultStoreType(Enum): INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" - ELSE = "else" + OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" class HookResultStore: @@ -131,7 +131,7 @@ def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): if latest: for dl_idx in range(self.num_dataloaders): dl_idx = str(dl_idx) - if self._internal_type == StorageType.ELSE: + if self._internal_type == ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP: latest_result = self._internals[dl_idx][-1] else: latest_result = self.get_latest_from_dict(dl_idx) @@ -200,7 +200,7 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: # [dataloader_idx][optimizer_idx][training_step_idx] is a list if len(extra_info) > 0: - self._internal_type = StorageType.INSIDE_BATCH_TRAIN_LOOP + self._internal_type = ResultStoreType.INSIDE_BATCH_TRAIN_LOOP # initialize dictionary if primary_key not in self._internals: self._internals[primary_key] = {} @@ -214,7 +214,7 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: # [dataloader_idx] is a list else: - self._internal_type = StorageType.ELSE + self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP if primary_key not in self._internals: self._internals[primary_key] = [] self._internals[primary_key].append(result) @@ -233,7 +233,7 @@ def auto_reduce_results_on_epoch_end(self) -> None: dl_idx = str(dl_idx) epoch_metrics = self._internals[dl_idx] - if self._internal_type == StorageType.INSIDE_BATCH_TRAIN_LOOP: + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: num_opt_idx = len(self._internals[dl_idx]) - 1
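
For orientation, the nested storage that `HookResultStore` builds up inside the training batch loop — `{dataloader_idx: {optimizer_idx: {batch_idx: [Result, ...]}}}`, scenario (2) in the docstring above — and the reduce-across-time-then-across-batches step performed at epoch end can be sketched independently of Lightning. The snippet below is only a minimal illustration under simplifying assumptions: a "result" is a plain dict of scalar tensors, and `reduce_across_time` / `reduce_on_epoch_end` are stand-in helpers that take a mean, not the actual `Result` class methods referenced in the patch.

```python
import torch

# Stand-ins for Result.reduce_across_time / Result.reduce_on_epoch_end:
# here a "result" is just a dict of scalar tensors and reduction is a mean.
def reduce_across_time(results):
    return {k: torch.stack([r[k] for r in results]).mean() for k in results[0]}

def reduce_on_epoch_end(results):
    return {k: torch.stack([r[k] for r in results]).mean() for k in results[0]}

# internals[dataloader_idx][optimizer_idx][batch_idx] -> list of per-tbptt-split results,
# with string keys mirroring the str(dl_idx) / str(opt_idx) convention used in the patch.
internals = {
    "0": {                                                                    # dataloader_idx
        "0": {                                                                # optimizer_idx
            "0": [{"loss": torch.tensor(0.9)}, {"loss": torch.tensor(0.7)}],  # batch 0, 2 tbptt splits
            "1": [{"loss": torch.tensor(0.5)}],                               # batch 1, 1 split
        }
    }
}

internals_reduced = {}
for dl_idx, per_opt in internals.items():
    internals_reduced[dl_idx] = {}
    for opt_idx, per_batch in per_opt.items():
        # first reduce across time (tbptt splits) within each batch ...
        time_reduced = [reduce_across_time(splits) for splits in per_batch.values()]
        # ... then reduce across batches at epoch end
        internals_reduced[dl_idx][opt_idx] = reduce_on_epoch_end(time_reduced)

print(internals_reduced)  # {'0': {'0': {'loss': tensor(0.6500)}}}
```

Once this reduction has run, the per-batch entries are no longer needed, which is why the patch deletes `self._internals[dl_idx]` after filling `self._internals_reduced[dl_idx]`.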