From 9c8701f2e2c09a984ac0322942c8f742a7cc6bc9 Mon Sep 17 00:00:00 2001 From: chaton Date: Thu, 5 Nov 2020 22:27:04 +0000 Subject: [PATCH] [feat] Logging refactor 2/n - train (#4495) * update logging * solve more bugs * replace Mapping by Dict * update on comments * resolve pep8 * Apply suggestions from code review Co-authored-by: ananthsub * Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Jirka Borovec * update on comments * typo * update for coverage * update test * update * Update tests/models/test_hooks.py Co-authored-by: Sean Naren * Update tests/models/test_hooks.py Co-authored-by: Sean Naren * update on comments * remove deepcopy * remove useless look for * another small optim * extra optim * remove lastest optim, can be source of bug * resolve bug * add docstring * optimize coverage * Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py Co-authored-by: Jirka Borovec * Update tests/trainer/logging_tests/test_distributed_logging.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/trainer/evaluation_loop.py Co-authored-by: Jirka Borovec * Update tests/trainer/logging/test_logger_connector.py Co-authored-by: Jirka Borovec * Update tests/trainer/logging_tests/test_train_loop_logging_1_0.py Co-authored-by: Jirka Borovec * update on comments * update * update on comments * update parity speed * get it down to 0.65 * update * 0.8 max_dif Co-authored-by: Jirka Borovec Co-authored-by: ananthsub Co-authored-by: Sean Naren Co-authored-by: William Falcon --- benchmarks/test_parity.py | 2 +- .../callback_hook_validator.py | 2 +- .../logger_connector/epoch_result_store.py | 252 ++++++++++++------ .../logger_connector/logger_connector.py | 137 +++++++--- pytorch_lightning/trainer/evaluation_loop.py | 3 + pytorch_lightning/trainer/trainer.py | 24 +- pytorch_lightning/trainer/training_loop.py | 113 +++----- tests/base/develop_utils.py | 2 +- tests/models/test_hooks.py | 8 +- .../test_eval_loop_dict_return.py | 15 ++ .../test_trainer_steps_dict_return.py | 17 +- .../test_trainer_steps_scalar_return.py | 9 +- .../trainer/logging/test_logger_connector.py | 192 +++++++++++-- .../logging_tests/test_distributed_logging.py | 5 +- .../test_train_loop_logging_1_0.py | 209 ++++++++++++++- 15 files changed, 733 insertions(+), 257 deletions(-) diff --git a/benchmarks/test_parity.py b/benchmarks/test_parity.py index d2b30afb23946..d2bc97deff598 100644 --- a/benchmarks/test_parity.py +++ b/benchmarks/test_parity.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize('cls_model,max_diff', [ (ParityModuleRNN, 0.05), - (ParityModuleMNIST, 0.70) + (ParityModuleMNIST, 0.8) ]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_pytorch_parity(tmpdir, cls_model, max_diff): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 3ce4b523545c3..e9c33cea70b8a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -192,7 +192,7 @@ def _on_validation_start_log(): @staticmethod def _on_validation_end_log(): """Called when the validation loop ends.""" - return {"on_step": [False], "on_epoch": [False, True]} + return None @staticmethod def _on_test_start_log(): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2a9d68807e694..2980b037c95f7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import defaultdict -from copy import deepcopy +import os +from collections import defaultdict, ChainMap from enum import Enum -from typing import Union, Tuple, Any, Mapping - +from typing import Union, Tuple, Any, Dict, Optional, List +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result @@ -92,73 +91,70 @@ def __init__(self, fx_name): self._internals_reduced = {} self._internal_type = None self.has_reduced = False + self._latest_ref = {} - def get_reduced_metrics(self): - return self._internals_reduced - - def add_dataloader_idx(self): - return len(self._internals) > 1 + @property + def has_several_dataloaders(self) -> bool: + return self.num_dataloaders > 1 @property - def num_dataloaders(self): - return len(self._internals) - - def get_latest_from_dict(self, dl_idx): - num_opt_idx = len(self._internals[dl_idx]) - 1 - assert num_opt_idx >= 0 - num_opt_idx = str(num_opt_idx) - num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1 - batch_indexes = [*self._internals[dl_idx][num_opt_idx].keys()] - # sort them by increasing order - batch_indexes.sort(key=float) - assert num_batch_idx >= 0 - return self._internals[dl_idx][num_opt_idx][batch_indexes[-1]][-1] + def num_dataloaders(self) -> int: + _inter = self._internals_reduced if self.has_reduced else self._internals + return len(_inter) def check_dataloader_idx(self, result: Result) -> bool: - add_dataloader_idx = False - try: - if len(result.keys()) > 1: - random_key = [*result.keys()][-1] - add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None - return add_dataloader_idx - return add_dataloader_idx - except Exception: - return add_dataloader_idx - - def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): + random_key = [*result.keys()][-1] + add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None + return add_dataloader_idx + + def get_lastest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict: results = {} - if latest: - for dl_idx in range(self.num_dataloaders): - dl_idx = str(dl_idx) - if self._internal_type == ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP: - latest_result = self._internals[dl_idx][-1] - else: - latest_result = self.get_latest_from_dict(dl_idx) - add_dataloader_idx = self.check_dataloader_idx(latest_result) - func = getattr(latest_result, func_name) - results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) - return results - raise NotImplementedError + add_dataloader_idx = self.check_dataloader_idx(latest_result) + func = getattr(latest_result, func_name) + results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) + return results - def get_batch_pbar_metrics(self, latest=True, *args, **kwargs): - return self.get_lastest_from_func_name("get_batch_pbar_metrics", *args, latest=latest, **kwargs) + def run_lastest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + """ + This function used cache_ref and cache_result to optimize loading metrics - def get_batch_log_metrics(self, latest=True, *args, **kwargs): - return self.get_lastest_from_func_name("get_batch_log_metrics", *args, latest=latest, **kwargs) + Context: As we update the logger_connector metrics on every `self.log` call, + and it can be pretty time consuming, especially when logging outside batch loop. + + HookResultStore keeps track of its latest added result object, + and cache its pbar and log metrics if already called on, + """ + results = [] + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + latest_result = self._latest_ref[dl_idx] + result = self.get_lastest_from_func_name(latest_result, func_name, *args, **kwargs) + results.append(result) + return results + + def get_batch_pbar_metrics(self, *args, **kwargs): + return self.run_lastest_batch_metrics_with_func_name("get_batch_pbar_metrics", + *args, + **kwargs) + + def get_batch_log_metrics(self, *args, **kwargs): + return self.run_lastest_batch_metrics_with_func_name("get_batch_log_metrics", + *args, + **kwargs) def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if isinstance(opt_metric, Result): func = getattr(opt_metric, func_name) metrics_to_log = func( *args, - add_dataloader_idx=self.add_dataloader_idx, + add_dataloader_idx=self.has_several_dataloaders, **kwargs) - results.update(metrics_to_log) + results.append(metrics_to_log) else: raise Exception("The provided opt_metric should be a Result Object. Something is wrong") - def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: - results = {} + def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + results = [] for dl_idx in range(self.num_dataloaders): dl_idx = str(dl_idx) opt_metrics = self._internals_reduced[dl_idx] @@ -169,13 +165,13 @@ def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) return results - def get_epoch_pbar_metrics(self, *args, **kwargs) -> Mapping: + def get_epoch_pbar_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self, *args, **kwargs) -> Mapping: + def get_epoch_log_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_log_metrics") - def get_forked_metrics(self, *args, **kwargs) -> Mapping: + def get_forked_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_forked_metrics") @staticmethod @@ -211,6 +207,8 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: self._append_to_structure(self._internals[primary_key], opt_idx, batch_idx, result) + self._latest_ref[primary_key] = result + # [dataloader_idx] is a list else: self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP @@ -218,6 +216,8 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: self._internals[primary_key] = [] self._internals[primary_key].append(result) + self._latest_ref[primary_key] = result + def auto_reduce_results_on_epoch_end(self) -> None: """ This function is called to reduce `self._internals` Result object. @@ -271,7 +271,7 @@ def auto_reduce_results_on_epoch_end(self) -> None: self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs # free memory - del self._internals[dl_idx] + del self._internals[dl_idx][opt_idx] else: # no need to reduce as called only once if len(epoch_metrics) == 1: @@ -301,13 +301,9 @@ def __repr__(self): class EpochResultStore: """ This class is defined for internal usage. - It holds all metrics logged using the self.log function using `HookResultStore` object. - The internal datastructure is as follow: - self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} - Pseudo Code Example: ``` model._current_fx_name = 'something' @@ -315,7 +311,6 @@ class EpochResultStore: model.log('a', ...) epoch_result_store.cache_result() ``` - """ def __init__(self, trainer, stage): self.trainer = trainer @@ -365,7 +360,7 @@ def current_model_info(self): model_ref = self.trainer.get_model() # extract hook information fx_name = model_ref._current_hook_fx_name - if fx_name == '': + if fx_name is None: fx_name = model_ref._current_fx_name dataloader_idx = model_ref._current_dataloader_idx return fx_name, dataloader_idx @@ -398,7 +393,7 @@ def cache_result(self) -> None: Result.attach_batch_size(self._batch_size, hook_result) self._internals[fx_name].append( - deepcopy(hook_result), + hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info) @@ -456,18 +451,22 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector.callback_metrics.update(callback_metrics) logger_connector.callback_metrics.pop("epoch", None) - def run_batch_from_func_name(self, func_name) -> Mapping: - results = {} + def run_batch_from_func_name(self, func_name) -> Dict: + results = [] for fx_name, hook_result in self._internals.items(): func = getattr(hook_result, func_name) - results.update(func(latest=True, include_forked_originals=False)) - return results + results.append(func(include_forked_originals=False)) + return dict(ChainMap(*sum(results, []))) - def get_latest_batch_log_metrics(self) -> Mapping: - return self.run_batch_from_func_name("get_batch_log_metrics") + def get_latest_batch_log_metrics(self) -> Dict: + batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics") + batch_log_metrics.update(self.legacy_batch_log_metrics) + return batch_log_metrics - def get_latest_batch_pbar_metrics(self) -> Mapping: - return self.run_batch_from_func_name("get_batch_pbar_metrics") + def get_latest_batch_pbar_metrics(self) -> Dict: + batch_pbar_metrics = self.run_batch_from_func_name("get_batch_pbar_metrics") + batch_pbar_metrics.update(self.legacy_batch_pbar_metrics) + return batch_pbar_metrics @property def has_reduced(self) -> bool: @@ -495,27 +494,24 @@ def has_batch_loop_finished(self, has_batch_loop_finished): self._has_batch_loop_finished = has_batch_loop_finished self.update_logger_connector() - def run_epoch_by_func_name(self, func_name) -> Mapping: + def run_epoch_by_func_name(self, func_name) -> Dict: if not self.has_reduced: self.auto_reduce_results_on_epoch_end() - results = {} + results = [] for fx_name, hook_result in self._internals.items(): func = getattr(hook_result, func_name) - results.update(func()) - return results + results.append(func()) + return dict(ChainMap(*sum(results, []))) - def get_epoch_pbar_metrics(self) -> Mapping: + def get_epoch_pbar_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self) -> Mapping: + def get_epoch_log_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_epoch_log_metrics") - def get_forked_metrics(self) -> Mapping: + def get_forked_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_forked_metrics") - def get_reduced_metrics(self) -> Mapping: - return self.run_epoch_by_func_name("get_reduced_metrics") - def reset(self): self._internals = {} self._dataloader_idx: Union[int, None] = None @@ -523,6 +519,96 @@ def reset(self): self._opt_idx: Union[int, None] = None self._batch_size: Union[int, None] = None self._has_batch_loop_finished = False + self.legacy_batch_log_metrics = {} + self.legacy_batch_pbar_metrics = {} + + def __call__( + self, + fx_name: Optional[Union[str, int]] = None, + dl_idx: Optional[Union[str, int]] = None, + opt_idx: Optional[Union[str, int]] = None, + batch_idx: Optional[Union[str, int]] = None, + split_idx: Optional[Union[str, int]] = None, + reduced: bool = False, + ): + """ + This function is an helper to access stored data + + It access data from the HookResultStore. Please, + check its data structure for better understanding + + Data can be accessed with the following chains: + + IF REDUCED: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx -> batch_idx -> split_idx + * ELSE fx_name -> dl_idx -> batch_idx + ELSE: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx + * ELSE fx_name -> dl_idx + + Note: + As soon as a param is None, it breaks the chain and returns associated stored data. + + Example:: + + result: Result = self(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True) + result['train_loss_epoch'] # aggregated train_loss over one epoch. + + Args: + + fx_name: Hook name from ModelHooks or Callback. Example: `training_step` + + dl_idx: Dataloader idx in short. It starts from 0 to num_dataloaders - 1 + + opt_idx: Optimizer idx in short. It starts from 0 to num_optimizers - 1 + + batch_idx: Index of batch idx seen during batch training or evaluation. + Works only with reduced=False + + split_idx: Index of split idx in training loop when ttbt is used. + + reduced: Data are being aggregated on on_epoch_end. + Indicates if we want to access aggregated Result or not. + """ + + hook_result = self[str(fx_name)] + + dl_idx = str(dl_idx) if dl_idx is not None else None + opt_idx = str(opt_idx) if opt_idx is not None else None + batch_idx = str(batch_idx) if batch_idx is not None else None + split_idx = int(split_idx) if split_idx is not None else None + + internal_type = hook_result._internal_type + + if reduced: + result = hook_result._internals_reduced + else: + result = hook_result._internals + + if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if not reduced: + if dl_idx is not None: + result = result[dl_idx] + if opt_idx is not None: + result = result[opt_idx] + if batch_idx is not None: + result = result[batch_idx] + if split_idx is not None: + result = result[split_idx] + else: + if dl_idx is not None: + result = result[dl_idx] + if opt_idx is not None: + result = result[opt_idx] + else: + if dl_idx is not None: + result = result[dl_idx] + if batch_idx and not reduced: + result = result[batch_idx] + + return result def __repr__(self): return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0a1cb836eda6d..946064660f818 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -44,25 +44,14 @@ def __init__(self, trainer): self._callback_hook_validator = CallbackHookNameValidator() self._current_stage = None - def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]: - """ Function to access cached_results using str or bool. Bool is used only for testing""" - stage_or_testing = str(stage_or_testing) - stages = self._stages - if stage_or_testing in self._stages: - return self._cached_results[stage_or_testing] - if stage_or_testing in LOOKUP_TABLE: - # Acces using trainer.testing - stage = LOOKUP_TABLE[stage_or_testing] - return self._cached_results[stage] - raise MisconfigurationException( - f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}" - f" or {LOOKUP_TABLE.keys()}" - ) + @property + def cached_results(self) -> Union[EpochResultStore, None]: + return self._cached_results[self._current_stage] def set_stage(self, stage_or_testing: str, reset:bool = False) -> None: self._current_stage = self._determine_stage(stage_or_testing) if reset: - self.cached_results(stage_or_testing).reset() + self.cached_results.reset() def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name, @@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.cached_results(testing)._batch_size = Result.extract_batch_size(batch) + self.cached_results._batch_size = Result.extract_batch_size(batch) - def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None: - self._cached_results["train"]._split_idx = split_idx - self._cached_results["train"]._opt_idx = opt_idx - self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch) + def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None: + self.cached_results._split_idx = split_idx + self.cached_results._opt_idx = opt_idx + self.cached_results._batch_size = Result.extract_batch_size(split_batch) def on_train_batch_end(self) -> None: - self._cached_results["train"]._split_idx = None - self._cached_results["train"]._opt_idx = None - self._cached_results["train"]._batch_size = None + self.cached_results._split_idx = None + self.cached_results._opt_idx = None + self.cached_results._batch_size = None def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: stage_or_testing = str(stage_or_testing) @@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps): self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps self.trainer.log_every_n_steps = log_every_n_steps + @property + def should_flush_logs(self): + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop + + @property + def should_update_logs(self): + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop + def configure_logger(self, logger): if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -130,6 +129,53 @@ def configure_logger(self, logger): else: self.trainer.logger = logger + def cache_training_step_metrics(self, opt_closure_result): + """ + This function is responsible to update + logger_connector internals metrics holder based for depreceated logging + """ + using_results_obj = isinstance(opt_closure_result.training_step_output, Result) + + # temporary dict to collect metrics + logged_metrics_tmp = {} + pbar_metrics_tmp = {} + callback_metrics_tmp = {} + + if using_results_obj: + batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics( + include_forked_originals=False + ) + logged_metrics_tmp.update(batch_log_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( + include_forked_originals=False + ) + pbar_metrics_tmp.update(batch_pbar_metrics) + + forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() + callback_metrics_tmp.update(forked_metrics) + callback_metrics_tmp.update(logged_metrics_tmp) + + else: + batch_log_metrics = opt_closure_result.training_step_output.log_metrics + logged_metrics_tmp.update(batch_log_metrics) + + callback_metrics = opt_closure_result.training_step_output.callback_metrics + callback_metrics_tmp.update(callback_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end + pbar_metrics_tmp.update(batch_pbar_metrics) + + # track progress bar metrics + if len(pbar_metrics_tmp) > 0: + self.add_progress_bar_metrics(pbar_metrics_tmp) + + self.callback_metrics.update(callback_metrics_tmp) + + # save legacy log metrics + self.logged_metrics.update(logged_metrics_tmp) + self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp) + def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, @@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod if num_loaders == 1: self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics) - def on_train_epoch_end(self, epoch_output): - pass + def on_train_epoch_end(self): + # inform cached logger connector epoch finished + self.cached_results.has_batch_loop_finished = True def log_train_epoch_end_metrics(self, epoch_output, @@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self, # ------------------ if is_1_0_result: # lightning module hook - epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers) + self.training_epoch_end(model, epoch_output, num_optimizers) # log/aggregate metrics automatically epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) - epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics()) - epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics()) # TODO: deprecate 1.0 else: @@ -459,6 +504,14 @@ def log_train_epoch_end_metrics(self, ) epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out + # it will perform reduction over epoch and return log metrics + cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics() + cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics() + + # update + epoch_log_metrics.update(cached_epoch_log_metrics) + epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics) + # -------------------------- # track results # -------------------------- @@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self, self.add_progress_bar_metrics(epoch_progress_bar_metrics) self.callback_metrics.update(epoch_progress_bar_metrics) + # reset epoch loop result for next epoch + self.cached_results.reset() + def training_epoch_end(self, model, epoch_output, num_optimizers): if not is_overridden('training_epoch_end', model=model): - return Result() + return # run training_epoch_end # refresh the result for custom logging at the epoch level model._current_fx_name = 'training_epoch_end' - model._results = Result() - epoch_output = self.__prepare_epoch_end_inputs(epoch_output) if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization: @@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers): # lightningmodule hook epoch_output = model.training_epoch_end(epoch_output) - model._current_fx_name = '' - if epoch_output is not None: raise MisconfigurationException('training_epoch_end expects a return of None. ' 'HINT: remove the return statement in training_epoch_end') - - # user can ALSO log at the end of an epoch - new_epoch_end_logs = model._results - return new_epoch_end_logs + # capture logging + self.trainer.logger_connector.cache_logged_metrics() def __run_legacy_training_epoch_end( self, @@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end( # run training_epoch_end # a list with a result per optimizer index + model._current_fx_name = 'training_epoch_end' epoch_output = model.training_epoch_end(epoch_output) + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + if isinstance(epoch_output, Result): epoch_log_metrics = epoch_output.epoch_log_metrics epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics @@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output): # reduce across training steps opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) - # with manual opt need 1+ metrics because meta is always there + # with manual opt need 1 + metrics because meta is always there if opt_outputs.minimize is not None: opt_outputs.minimize = opt_outputs.minimize.mean() epoch_log_metrics.update(opt_outputs.epoch_log_metrics) @@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): def log_train_step_metrics(self, batch_output): # when metrics should be logged - should_log_metrics = ( - (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop - ) - if should_log_metrics or self.trainer.fast_dev_run: + if self.should_update_logs or self.trainer.fast_dev_run: # logs user requested information to logger - metrics = batch_output.batch_log_metrics + metrics = self.cached_results.get_latest_batch_log_metrics() grad_norm_dic = batch_output.grad_norm_dic if len(metrics) > 0 or len(grad_norm_dic) > 0: self.log_metrics(metrics, grad_norm_dic) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 89a242dbfd886..6ebab1ade0f1d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -358,6 +358,9 @@ def __log_result_step_metrics(self, output, batch_idx): step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False) + cached_batch_log_metrics = \ + self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics() + if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph metrics_by_epoch = {} diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 49cf232f76ac7..2d4e2c0d9e4bd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -838,7 +838,25 @@ def call_setup_hook(self, model): self.setup(stage_name) model.setup(stage_name) + def _reset_result_and_set_hook_fx_name(self, hook_name): + model_ref = self.get_model() + if model_ref is not None: + # used to track current hook name called + model_ref._results = Result() + model_ref._current_hook_fx_name = hook_name + + def _cache_logged_metrics(self): + model_ref = self.get_model() + if model_ref is not None: + # capture logging for this hook + self.logger_connector.cache_logged_metrics() + def call_hook(self, hook_name, *args, **kwargs): + # temporary. Don't modify evaluation behaviour + if self.logger_connector._current_stage == "train": + # set hook_name to model + reset Result obj + self._reset_result_and_set_hook_fx_name(hook_name) + # always profile hooks with self.profiler.profile(hook_name): @@ -860,4 +878,8 @@ def call_hook(self, hook_name, *args, **kwargs): accelerator_hook = getattr(self.accelerator_backend, hook_name) output = accelerator_hook(*args, **kwargs) - return output + # temporary. Don't modify evaluation behaviour + if self.logger_connector._current_stage == "train": + # capture logging + self._cache_logged_metrics() + return output diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3845b7eb728ac..2f66f5b1a600e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -251,12 +251,15 @@ def on_train_epoch_start(self, epoch): self.trainer.call_hook("on_train_epoch_start") def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): + # hook + self.trainer.call_hook('on_batch_end') + self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) + # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) - # hook - self.trainer.call_hook("on_batch_end") - self.trainer.call_hook("on_train_batch_end", epoch_end_outputs, batch, batch_idx, dataloader_idx) + # reset batch logger internals + self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: @@ -305,13 +308,16 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging - model = self.trainer.get_model() - model._results = Result() - model._current_fx_name = "training_step" + model_ref = self.trainer.get_model() with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + + # manually capture logged metrics + model_ref._current_fx_name = 'training_step' training_step_output = self.trainer.accelerator_backend.training_step(args) + self.trainer.logger_connector.cache_logged_metrics() + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( @@ -484,35 +490,6 @@ def _track_gradient_norm(self): grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict - def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics): - # track callback metrics - callback_metrics = opt_closure_result.training_step_output.callback_metrics - - # decide which metrics to log (results vs dict return) - using_results_obj = isinstance(opt_closure_result.training_step_output, Result) - if using_results_obj: - metrics_to_log = opt_closure_result.training_step_output.get_batch_log_metrics( - include_forked_originals=False - ) - step_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( - include_forked_originals=False - ) - forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() - callback_metrics.update(forked_metrics) - else: - metrics_to_log = opt_closure_result.training_step_output.log_metrics - step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end - - # track batch log metrics - batch_log_metrics.append(metrics_to_log) - - # track progress bar metrics - if len(step_pbar_metrics) > 0: - self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) - self.trainer.logger_connector.callback_metrics.update(step_pbar_metrics) - - batch_callback_metrics.append(callback_metrics) - def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if isinstance(opt_closure_result.training_step_output, Result): @@ -578,6 +555,8 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) + # reset stage to train + self.trainer.logger_connector.set_stage("train") # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) @@ -586,7 +565,6 @@ def run_training_epoch(self): # update LR schedulers monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True @@ -612,19 +590,19 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() + # epoch end hook + self.run_on_epoch_end_hook(epoch_output) + # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( - epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers + epoch_output, + self.checkpoint_accumulator, + self.early_stopping_accumulator, + self.num_optimizers ) - # hook - self.trainer.logger_connector.on_train_epoch_end(epoch_output) - # when no val loop is present or fast-dev-run still need to call checkpoints - self.check_checkpoint_callback(not (should_check_val or is_overridden("validation_step", model))) - - # epoch end hook - self.run_on_epoch_end_hook(epoch_output) + self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) # increment the global step once # progress global step according to grads progress @@ -634,12 +612,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} - # track all metrics for callbacks - batch_callback_metrics = [] - - # track metrics to log - batch_log_metrics = [] - # bookkeeping using_results_obj = False self.trainer.hiddens = None @@ -683,8 +655,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) batch_outputs = self._process_closure_result( - batch_callback_metrics=batch_callback_metrics, - batch_log_metrics=batch_log_metrics, batch_outputs=batch_outputs, opt_idx=opt_idx, ) @@ -711,15 +681,18 @@ def train_step_and_backward_closure(): self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + self._curr_step_result = self.training_step( + split_batch, + batch_idx, + opt_idx, + self.trainer.hiddens + ) if self._curr_step_result is None: # user decided to skip optimization continue batch_outputs = self._process_closure_result( - batch_callback_metrics=batch_callback_metrics, - batch_log_metrics=batch_log_metrics, batch_outputs=batch_outputs, opt_idx=opt_idx, ) @@ -737,19 +710,9 @@ def train_step_and_backward_closure(): # update running loss + reset accumulated loss self.update_running_loss() - # collapse all metrics into one dict - batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} - - # track all metrics for callbacks - self.trainer.logger_connector.callback_metrics.update(batch_log_metrics) - self.trainer.logger_connector.callback_metrics.update( - {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None} - ) - result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, - batch_log_metrics=batch_log_metrics, training_step_output_for_epoch_end=batch_outputs, ) return result @@ -762,14 +725,14 @@ def block_ddp_sync_behaviour(self): yield def _process_closure_result( - self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int + self, batch_outputs: list, opt_idx: int ) -> list: opt_closure_result = self._curr_step_result if opt_closure_result is not None: - # log metrics - self.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics) + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) @@ -842,8 +805,10 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): - self.trainer.call_hook("on_epoch_end") - self.trainer.call_hook("on_train_epoch_end", epoch_output) + self.trainer.call_hook('on_epoch_end') + self.trainer.call_hook('on_train_epoch_end', epoch_output) + + self.trainer.logger_connector.on_train_epoch_end() def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() @@ -898,10 +863,8 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def save_loggers_on_train_batch_end(self): # when loggers should save to disk - should_save_log = ( - self.trainer.global_step + 1 - ) % self.trainer.flush_logs_every_n_steps == 0 or self.trainer.should_stop - if should_save_log or self.trainer.fast_dev_run: + should_flush_logs = self.trainer.logger_connector.should_flush_logs + if should_flush_logs or self.trainer.fast_dev_run: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() @@ -955,7 +918,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_batch_start(split_idx, opt_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py index ba0d20c2c8389..9c88ba1b7e4d3 100644 --- a/tests/base/develop_utils.py +++ b/tests/base/develop_utils.py @@ -32,7 +32,7 @@ def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1): f"lightning {diffs} was slower than PT (threshold {max_diff})" -def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.6): +def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.55): # assert speeds diffs = np.asarray(pl_times) - np.asarray(pt_times) # norm by vanila time diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 886e0db4e7854..bccc5262a5bda 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -307,7 +307,7 @@ def on_test_model_train(self): trainer.fit(model) - assert model.called == [ + expected = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -341,10 +341,12 @@ def on_test_model_train(self): 'on_fit_end', ] + assert model.called == expected + model2 = HookedModel() trainer.test(model2) - assert model2.called == [ + expected = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -356,3 +358,5 @@ def on_test_model_train(self): 'on_test_model_train', 'on_fit_end', ] + + assert model2.called == expected diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 31b60e1d0be8b..6329480e10a11 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -14,6 +14,7 @@ """ Tests to ensure that the training loop works with a dict """ +import os from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer from tests.base.deterministic_model import DeterministicModel @@ -125,6 +126,9 @@ def test_validation_step_dict_return(tmpdir): Test that val step can return a dict with all the expected keys and they end up in the correct place """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -166,6 +170,8 @@ def test_val_step_step_end_no_return(tmpdir): """ Test that val step + val step end work (with no return in val step end) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -197,6 +203,9 @@ def test_val_step_step_end(tmpdir): """ Test that val step + val step end work """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -241,6 +250,9 @@ def test_no_val_step_end(tmpdir): """ Test that val step + val epoch end """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -284,6 +296,9 @@ def test_full_val_loop(tmpdir): """ Test that val step + val step end + val epoch end """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py index 7e8588ce9f6b2..8d1aaf1b3c548 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py @@ -44,9 +44,10 @@ def test_training_step_dict(tmpdir): break out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 12.0 - assert out.batch_log_metrics['log_acc2'] == 7.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0 train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 @@ -92,8 +93,8 @@ def training_step_with_step_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 14.0 - assert out.batch_log_metrics['log_acc2'] == 9.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0 train_step_end_out = out.training_step_output_for_epoch_end pbar_metrics = train_step_end_out['progress_bar'] @@ -133,8 +134,8 @@ def test_full_training_loop_dict(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 14.0 - assert out.batch_log_metrics['log_acc2'] == 9.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0 # get the output of the first optimizer train_step_end_out = out.training_step_output_for_epoch_end @@ -220,8 +221,8 @@ def test_train_step_epoch_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 12.0 - assert out.batch_log_metrics['log_acc2'] == 7.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0 # outputs are for 1 optimizer and no tbptt train_step_end_out = out.training_step_output_for_epoch_end diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py index b5eae913ca428..2a66f743a49ef 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py @@ -15,6 +15,7 @@ Tests to ensure that the training loop works with a scalar """ import torch +import os from pytorch_lightning import Trainer from tests.base.deterministic_model import DeterministicModel @@ -46,7 +47,6 @@ def test_training_step_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -84,7 +84,6 @@ def training_step_scalar_with_step_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -104,6 +103,8 @@ def test_full_training_loop_scalar(tmpdir): Checks train_step + training_step_end + training_epoch_end (all with scalar return from train_step) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_scalar_return model.training_step_end = model.training_step_end_scalar @@ -132,7 +133,6 @@ def test_full_training_loop_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -152,6 +152,8 @@ def test_train_step_epoch_end_scalar(tmpdir): Checks train_step + training_epoch_end (NO training_step_end) (with scalar return) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_scalar_return model.training_step_end = None @@ -176,7 +178,6 @@ def test_train_step_epoch_end_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 0f27f2ca4fef4..08936f89eb9f8 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -17,15 +17,19 @@ import os import torch import pytest - +from copy import deepcopy from pytorch_lightning.trainer import Trainer from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDataset class Helper: - def decorator_with_arguments(fx_name='', hook_fx_name=''): + def decorator_with_arguments(fx_name='', hook_fx_name=None): def decorator(func): def wrapper(self, *args, **kwargs): # Set information @@ -65,9 +69,9 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss, on_step=True, on_epoch=True) return {"loss": loss} - def val_dataloader(self): - return [torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64))] + def on_train_epoch_end(self, outputs): + # save objects as it will be reset at the end of epoch. + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) model = TestModel() model.val_dataloader = None @@ -82,21 +86,31 @@ def val_dataloader(self): ) trainer.fit(model) - assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']) == 2 - assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0'][0]["train_loss"] == model.train_losses[0] - assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['1'][0]["train_loss"] == model.train_losses[1] + train_results = model.train_results - # assert reduction didn't happen yet - assert trainer.logger_connector.cached_results("train").has_reduced is False + assert len(train_results(fx_name="training_step", dl_idx="0", opt_idx="0")) == 2 + generated = train_results(fx_name="training_step", + dl_idx="0", + opt_idx="0", + batch_idx="0", + split_idx="0")["train_loss"] + assert generated == model.train_losses[0] + generated = train_results(fx_name="training_step", + dl_idx="0", + opt_idx="0", + batch_idx="1", + split_idx="0")["train_loss"] + assert generated == model.train_losses[1] - # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + assert train_results.has_reduced is not True - # assert reduction did happen - assert trainer.logger_connector.cached_results("train").has_reduced is True + train_results.has_batch_loop_finished = True - assert trainer.logger_connector.cached_results("train")["training_step"]\ - ._internals_reduced["0"]["0"]['train_loss_epoch'].item() == torch.stack(model.train_losses).mean().item() + assert train_results.has_reduced is True + + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['train_loss_epoch'].item() + excepted = torch.stack(model.train_losses).mean().item() + assert generated == excepted def test__logger_connector__epoch_result_store__train__ttbt(tmpdir): @@ -163,6 +177,10 @@ def train_dataloader(self): sampler=None, ) + def on_train_epoch_end(self, outputs): + # save objects as it will be reset at the end of epoch. + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) + model = TestModel() model.training_epoch_end = None model.example_input_array = torch.randn(5, truncated_bptt_steps) @@ -178,19 +196,22 @@ def train_dataloader(self): ) trainer.fit(model) - assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0']) == len(model.train_losses) + train_results = model.train_results + + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", batch_idx="0") + assert len(generated) == len(model.train_losses) # assert reduction didn't happen yet - assert trainer.logger_connector.cached_results("train").has_reduced is False + assert train_results.has_reduced is False # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + train_results.has_batch_loop_finished = True # assert reduction did happen - assert trainer.logger_connector.cached_results("train").has_reduced is True + assert train_results.has_reduced is True - assert trainer.logger_connector.cached_results("train")['training_step']\ - ._internals_reduced['0']['0']["a_epoch"].item() == torch.stack(model.train_losses).mean().item() + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['a_epoch'].item() + assert generated == torch.stack(model.train_losses).mean().item() @pytest.mark.parametrize('num_dataloaders', [1, 2]) @@ -206,11 +227,11 @@ class TestModel(BoringModel): test_losses = {} @Helper.decorator_with_arguments(fx_name="test_step") - def test_step(self, batch, batch_idx, dataloader_idx=0): + def test_step(self, batch, batch_idx, dl_idx=0): output = self.layer(batch) loss = self.loss(batch, output) - primary_key = str(dataloader_idx) + primary_key = str(dl_idx) if primary_key not in self.test_losses: self.test_losses[primary_key] = [] @@ -239,11 +260,126 @@ def test_dataloader(self): ) trainer.test(model) - assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals) == num_dataloaders + test_results = trainer.logger_connector._cached_results["test"] + + generated = test_results(fx_name="test_step") + assert len(generated) == num_dataloaders + for dl_idx in range(num_dataloaders): - assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals[str(dl_idx)]) == limit_test_batches - trainer.logger_connector.cached_results("test").has_batch_loop_finished = True + generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx))) + assert generated == limit_test_batches + + test_results.has_batch_loop_finished = True + for dl_idx in range(num_dataloaders): expected = torch.stack(model.test_losses[str(dl_idx)]).mean() - generated = trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] + generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"] assert abs(expected.item() - generated.item()) < 1e-6 + + +def test_call_back_validator(tmpdir): + + funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) + + callbacks_func = [ + 'on_after_backward', + 'on_batch_end', + 'on_batch_start', + 'on_before_zero_grad', + 'on_epoch_end', + 'on_epoch_start', + 'on_fit_end', + 'on_fit_start', + 'on_init_end', 'on_init_start', + 'on_keyboard_interrupt', + 'on_load_checkpoint', + 'on_pretrain_routine_end', + 'on_pretrain_routine_start', + 'on_sanity_check_end', + 'on_sanity_check_start', + 'on_save_checkpoint', + 'on_test_batch_end', + 'on_test_batch_start', + 'on_test_end', + 'on_test_epoch_end', + 'on_test_epoch_start', + 'on_test_start', + 'on_train_batch_end', + 'on_train_batch_start', + 'on_train_end', + 'on_train_epoch_end', + 'on_train_epoch_start', + 'on_train_start', + 'on_validation_batch_end', + 'on_validation_batch_start', + 'on_validation_end', + 'on_validation_epoch_end', + 'on_validation_epoch_start', + 'on_validation_start', + 'setup', + 'teardown', + ] + + not_supported = [ + "on_fit_end", + "on_fit_start", + "on_init_end", + "on_init_start", + "on_keyboard_interrupt", + "on_load_checkpoint", + "on_pretrain_routine_end", + "on_pretrain_routine_start", + "on_sanity_check_end", + "on_sanity_check_start", + "on_save_checkpoint", + "on_test_end", + "on_train_end", + "on_validation_end", + "setup", + "teardown", + ] + + assert funcs_name == callbacks_func, """Detected new callback function. + Need to add its logging permission to CallbackHookNameValidator and update this test""" + + validator = CallbackHookNameValidator() + + for func_name in funcs_name: + # This summurize where and what is currently possible to log using `self.log` function. + is_stage = "train" in func_name or "test" in func_name or "validation" in func_name + is_start = "start" in func_name or "batch" in func_name + on_step = is_stage and is_start + on_epoch = True + # creating allowed condition + allowed = ( + is_stage + or "batch" in func_name + or "epoch" in func_name + or "grad" in func_name + or "backward" in func_name + ) + allowed = ( + allowed + and "pretrain" not in func_name + and func_name not in ["on_train_end", "on_test_end", "on_validation_end"] + ) + if allowed: + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=on_step, + on_epoch=on_epoch) + if not is_start and is_stage: + with pytest.raises(MisconfigurationException, match="function supports only"): + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=True, + on_epoch=on_epoch) + else: + assert func_name in not_supported + with pytest.raises(MisconfigurationException, match="function doesn't support"): + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=on_step, + on_epoch=on_epoch) + + result = validator.check_logging_in_callbacks(current_hook_fx_name=None, + on_step=None, + on_epoch=None) + assert result is None diff --git a/tests/trainer/logging_tests/test_distributed_logging.py b/tests/trainer/logging_tests/test_distributed_logging.py index 5fdd021dcc0ae..a600317a024c9 100644 --- a/tests/trainer/logging_tests/test_distributed_logging.py +++ b/tests/trainer/logging_tests/test_distributed_logging.py @@ -26,8 +26,9 @@ def on_pretrain_routine_end(self) -> None: with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m: self.trainer.logger_connector.log_metrics({'a': 2}, {}) logged_times = m.call_count - expected = 1 if self.global_rank == 0 else 0 - assert logged_times == expected, 'actual logger called from non-global zero' + expected = int(self.trainer.is_global_zero) + msg = f'actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}' + assert logged_times == expected, msg @pytest.mark.skipif(platform.system() == "Windows", diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 414264894e639..60ff33b402e4b 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -14,15 +14,22 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -from pytorch_lightning.core.lightning import LightningModule -from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset + import os -import torch +import collections import pytest +import itertools +import numpy as np +import torch +from torch.utils.data import Dataset + +import pytorch_lightning as pl +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer, callbacks + +from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset from tests.base.deterministic_model import DeterministicModel -from torch.utils.data import Dataset def test__training_step__log(tmpdir): @@ -324,12 +331,12 @@ def training_step(self, batch, batch_idx, hiddens): assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss_val = torch.nn.functional.mse_loss( + loss = torch.nn.functional.mse_loss( pred, y_tensor.view(batch_size, truncated_bptt_steps)) - self.log('a', loss_val, on_epoch=True) + self.log('a', loss, on_epoch=True) - return {'loss': loss_val, 'hiddens': self.test_hidden} + return {'loss': loss, 'hiddens': self.test_hidden} def on_train_epoch_start(self) -> None: self.test_hidden = None @@ -398,8 +405,10 @@ def val_dataloader(self): generated = set(trainer.logger_connector.logged_metrics) expected = { + 'a_step', 'a_epoch', - 'n_step/epoch_0', 'n_epoch', + 'n_step/epoch_0', + 'n_epoch', 'epoch' } @@ -489,3 +498,187 @@ def validation_step(self, batch, batch_idx): weights_summary=None, ) trainer.fit(model, train_data, val_data) + + +def test_log_works_in_train_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, + on_steps=[], on_epochs=[], prob_bars=[]): + self.funcs_called_count[func_name] += 1 + for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, + on_epoch=on_epoch, prog_bar=prog_bar) + + # catch information for verification + + # on on_train_start is outside the main loop. Won't be called + if func_name == "on_train_start": + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + # Saved only values from second epoch, so we can compute its mean or latest. + if pl_module.trainer.current_epoch == 1: + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + forked = on_step and on_epoch + + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": forked, + "func_name": func_name} + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + def on_train_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_end', 7, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.make_logging(pl_module, 'on_train_epoch_end', 9, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + class TestModel(BoringModel): + + manual_loss = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.manual_loss.append(loss) + self.log('train_loss', loss) + return {"loss": loss} + + max_epochs = 2 + limit_train_batches = 2 + model = TestModel() + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback] + ) + trainer.fit(model) + + assert test_callback.funcs_called_count["on_train_start"] == 1 + assert test_callback.funcs_called_count["on_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_batch_start"] == 4 + assert test_callback.funcs_called_count["on_train_batch_start"] == 4 + assert test_callback.funcs_called_count["on_batch_end"] == 4 + assert test_callback.funcs_called_count["on_train_batch_end"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] + assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] + trainer.callback_metrics.pop("train_loss") + + for func_name, output_value in trainer.callback_metrics.items(): + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics