From b375a2646c85d896a0c7b84c6485cf19a551df5b Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 6 Sep 2020 08:04:08 -0400
Subject: [PATCH] ref: inner train loop (intermediate step) 9/n (#3368)

* ref: inner train loop (intermediate step) 9/n

* ref: inner train loop (intermediate step) 9/n

* ref: inner train loop (intermediate step) 9/n

* ref: inner train loop (intermediate step) 9/n
---
 pytorch_lightning/trainer/training_loop.py | 25 +++--------------
 .../trainer/training_loop_temp.py          | 27 +++++++++++++++++++
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 5655b4a7f8939..70d76f683bd07 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -787,31 +787,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                     optimizer,
                     self.hiddens
                 )
 
-                using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
-                # ------------------------------
-                # POST forward bookkeeping
-                # ------------------------------
-                batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
-
-                # add metrics to loggers
-                if using_results_obj:
-                    metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
-                    step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
-                else:
-                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
-                    step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
-
-                # track metrics
-                batch_log_metrics.append(metrics_to_log)
-                if len(step_pbar_metrics) > 0:
-                    self.add_progress_bar_metrics(step_pbar_metrics)
+                # log metrics
+                self.train_loop.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics)
 
                 # track hiddens
-                self.hiddens = opt_closure_result.hiddens
-
-                if using_results_obj:
-                    opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
+                self.hiddens = self.train_loop.process_hiddens(opt_closure_result)
 
                 # check if loss or model weights are nan
                 if self.terminate_on_nan:
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 8f77aa9d9e46d..89f87e137e9b7 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -253,3 +253,30 @@ def _track_gradient_norm(self, batch_idx):
             grad_norm_dic = model.grad_norm(
                 self.trainer.track_grad_norm)
         return grad_norm_dic
+
+    def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics):
+        # track callback metrics
+        callback_metrics = opt_closure_result.training_step_output.callback_metrics
+        batch_callback_metrics.append(callback_metrics)
+
+        # decide which metrics to log (results vs dict return)
+        using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
+        if using_results_obj:
+            metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
+            step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
+        else:
+            metrics_to_log = opt_closure_result.training_step_output.log_metrics
+            step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
+
+        # track batch log metrics
+        batch_log_metrics.append(metrics_to_log)
+
+        # track progress bar metrics
+        if len(step_pbar_metrics) > 0:
+            self.trainer.add_progress_bar_metrics(step_pbar_metrics)
+
+    def process_hiddens(self, opt_closure_result):
+        hiddens = opt_closure_result.hiddens
+        if isinstance(opt_closure_result.training_step_output, Result):
+            opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
+        return hiddens
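
Note for readers following the refactor series: the one behavioral subtlety in the moved code is that a training_step output can be either a Result object or the legacy dict-style return, and the two expose step metrics under different attribute names. The plain-Python sketch below illustrates that branch in isolation; FakeResultOutput, FakeDictOutput, and pick_step_metrics are invented for this example and are not Lightning classes, they only stand in for the real Result / dict returns that log_training_step_metrics dispatches on.

    class FakeResultOutput:
        """Stand-in for a Result-style training_step output (invented for this sketch)."""
        callback_metrics = {'train_acc': 0.91}
        batch_log_metrics = {'train_loss': 0.12}   # sent to loggers each step
        batch_pbar_metrics = {'train_loss': 0.12}  # shown on the progress bar

    class FakeDictOutput:
        """Stand-in for the legacy dict-style training_step output."""
        callback_metrics = {'train_acc': 0.91}
        log_metrics = {'train_loss': 0.12}
        pbar_on_batch_end = {'train_loss': 0.12}

    def pick_step_metrics(training_step_output):
        # Mirrors the branch inside log_training_step_metrics: Result objects
        # expose batch_log_metrics / batch_pbar_metrics, while the legacy dict
        # return exposes log_metrics / pbar_on_batch_end.
        if isinstance(training_step_output, FakeResultOutput):
            return (training_step_output.batch_log_metrics,
                    training_step_output.batch_pbar_metrics)
        return (training_step_output.log_metrics,
                training_step_output.pbar_on_batch_end)

    for output in (FakeResultOutput(), FakeDictOutput()):
        log_metrics, pbar_metrics = pick_step_metrics(output)
        print(type(output).__name__, log_metrics, pbar_metrics)

Both branches yield the same pair of dicts here, which is the point of the extraction: callers of the new helper no longer need to know which return style the user's training_step produced.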