From 1906d06fb90355a35a6446cdcc00375d7a53afb3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:21:01 -0400
Subject: [PATCH 1/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop.py      | 65 ++----------------
 .../trainer/training_loop_temp.py               | 67 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 62 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 20d654a5dba70..eddd0e4280abc 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -934,73 +934,16 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         """
         wrap the forward step in a closure so second order methods work
         """
-        # ---------------------------
-        # FORWARD (TRAINING STEP + TRAIN STEP END)
-        # ---------------------------
-        with self.profiler.profile('model_forward'):
-            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
-            training_step_output = self.accelerator_backend.training_step(args)
-            training_step_output = self.call_hook('training_step_end', training_step_output)
-
-        # ----------------------------
-        # PROCESS THE RESULT
-        # ----------------------------
-        # format and reduce outputs accordingly
-        training_step_output_for_epoch_end = training_step_output
-        is_result_obj = isinstance(training_step_output, Result)
-
-        # track batch size for weighted average
-        if is_result_obj:
-            training_step_output.track_batch_size(len(split_batch))
-
-        # don't allow EvalResult in the training_step
-        if isinstance(training_step_output, EvalResult):
-            raise MisconfigurationException('training_step cannot return EvalResult, '
-                                            'use a dict or TrainResult instead')
-
-        # handle regular dicts
-        if not is_result_obj:
-            training_step_output = self.process_output(training_step_output, train=True)
-
-            training_step_output = AttributeDict(
-                batch_loss=training_step_output[0],
-                pbar_on_batch_end=training_step_output[1],
-                log_metrics=training_step_output[2],
-                callback_metrics=training_step_output[3],
-                hiddens=training_step_output[4],
-            )
-
-        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
-            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
-        elif is_result_obj:
-            training_step_output_for_epoch_end = copy(training_step_output)
-            training_step_output_for_epoch_end.detach()
-        else:
-            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
-
-        # accumulate loss
-        # (if accumulate_grad_batches = 1 no effect)
-        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
-        closure_loss = closure_loss / self.accumulate_grad_batches
-
-        # the loss will get scaled for amp. avoid any modifications to it
-        untouched_loss = closure_loss.detach().clone()
+        # lightning module hook
+        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, optimizer, hiddens)

         # backward pass
         with self.profiler.profile('model_backward'):
-            closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)
+            result.closure_loss = self.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)

         # hook
-        self.train_loop.on_after_backward(training_step_output, batch_idx, untouched_loss)
+        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.untouched_loss)

-        # result
-        result = AttributeDict(
-            loss=untouched_loss,
-            training_step_output=training_step_output,
-            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
-            hiddens=training_step_output.hiddens,
-        )
         return result

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 5164a4395c625..82d15e01dfb83 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -5,8 +5,12 @@
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.supporters import Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.core.step_result import Result
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities.parsing import AttributeDict
+from copy import copy


 class TrainLoop:
@@ -143,3 +147,64 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):

         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
+
+    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
+        with self.trainer.profiler.profile('model_forward'):
+            args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
+            training_step_output = self.trainer.accelerator_backend.training_step(args)
+            training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
+
+        # ----------------------------
+        # PROCESS THE RESULT
+        # ----------------------------
+        # format and reduce outputs accordingly
+        training_step_output_for_epoch_end = training_step_output
+        is_result_obj = isinstance(training_step_output, Result)
+
+        # track batch size for weighted average
+        if is_result_obj:
+            training_step_output.track_batch_size(len(split_batch))
+
+        # don't allow EvalResult in the training_step
+        if isinstance(training_step_output, EvalResult):
+            raise MisconfigurationException('training_step cannot return EvalResult, '
+                                            'use a dict or TrainResult instead')
+
+        # handle regular dicts
+        if not is_result_obj:
+            training_step_output = self.trainer.process_output(training_step_output, train=True)
+
+            training_step_output = AttributeDict(
+                batch_loss=training_step_output[0],
+                pbar_on_batch_end=training_step_output[1],
+                log_metrics=training_step_output[2],
+                callback_metrics=training_step_output[3],
+                hiddens=training_step_output[4],
+            )
+
+        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
+        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+        elif is_result_obj:
+            training_step_output_for_epoch_end = copy(training_step_output)
+            training_step_output_for_epoch_end.detach()
+        else:
+            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+
+        # accumulate loss
+        # (if accumulate_grad_batches = 1 no effect)
+        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
+        closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+
+        # the loss will get scaled for amp. avoid any modifications to it
+        untouched_loss = closure_loss.detach().clone()
+
+        # result
+        result = AttributeDict(
+            closure_loss=closure_loss,
+            loss=untouched_loss,
+            training_step_output=training_step_output,
+            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
+            hiddens=training_step_output.hiddens,
+        )
+        return result

From 7a570427359c1bf801d0e204a5adac61b2267892 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:24:19 -0400
Subject: [PATCH 2/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index eddd0e4280abc..d4138ad2fddeb 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -935,7 +935,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         wrap the forward step in a closure so second order methods work
         """
         # lightning module hook
-        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, optimizer, hiddens)
+        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)

         # backward pass
         with self.profiler.profile('model_backward'):

From 4528b8e178f1d1c2abe2bb4e3ccfd00dc1ede2cc Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:26:45 -0400
Subject: [PATCH 3/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop.py      | 3 +--
 pytorch_lightning/trainer/training_loop_temp.py | 5 +++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d4138ad2fddeb..29d5850b11a8f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -938,8 +938,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)

         # backward pass
-        with self.profiler.profile('model_backward'):
-            result.closure_loss = self.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+        self.train_loop.backward(result, optimizer, opt_idx)

         # hook
         self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.untouched_loss)
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 82d15e01dfb83..5591d135ee180 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -134,6 +134,11 @@ def get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.trainer.optimizers[opt_idx])]

+    def backward(self, result, optimizer, opt_idx):
+        # backward pass
+        with self.trainer.profiler.profile('model_backward'):
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+
     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)


From 6a1fdefee47a61d7820aef7e2df8b7b347f3337a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:30:28 -0400
Subject: [PATCH 4/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop_temp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 5591d135ee180..6df74977726f2 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -137,7 +137,7 @@ def get_optimizers_iterable(self):
     def backward(self, result, optimizer, opt_idx):
         # backward pass
         with self.trainer.profiler.profile('model_backward'):
-            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.loss, optimizer, opt_idx)

     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)

From b2705cf8703ed9fa340d44e0fc3944d65783659e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:32:59 -0400
Subject: [PATCH 5/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop_temp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 6df74977726f2..5591d135ee180 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -137,7 +137,7 @@ def get_optimizers_iterable(self):
     def backward(self, result, optimizer, opt_idx):
         # backward pass
         with self.trainer.profiler.profile('model_backward'):
-            result.closure_loss = self.trainer.accelerator_backend.backward(result.loss, optimizer, opt_idx)
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)

     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)

From 6df11fbc15e3bc13215a61174d6050fec62da851 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 10:38:14 -0400
Subject: [PATCH 6/6] ref: inner train loop (intermediate step) 3/n

---
 pytorch_lightning/trainer/training_loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 29d5850b11a8f..361d063baf106 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -941,7 +941,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         self.train_loop.backward(result, optimizer, opt_idx)

         # hook
-        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.untouched_loss)
+        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.loss)

         return result
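
Taken together, the six patches above reduce Trainer.optimizer_closure to a thin wrapper around the new TrainLoop object. The condensed sketch below is pieced together from the diffs as an illustration of the resulting call flow; it is not the verbatim final file, and the comments summarize behavior shown in the added TrainLoop methods.

    # Condensed sketch of Trainer.optimizer_closure after PATCH 6/6 (illustrative only)
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        # forward: TrainLoop.training_step builds the train args, runs the accelerator's
        # training_step plus the 'training_step_end' hook, and packs an AttributeDict with
        # closure_loss (divided by accumulate_grad_batches), loss (a detached clone),
        # training_step_output, training_step_output_for_epoch_end and hiddens
        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)

        # backward: TrainLoop.backward runs accelerator_backend.backward inside the
        # 'model_backward' profiler block and writes the returned loss back into
        # result.closure_loss
        self.train_loop.backward(result, optimizer, opt_idx)

        # hook: dev-debugger loss tracking now receives the detached result.loss
        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.loss)

        return result

PATCH 4/6 briefly passes result.loss (the detached clone) into accelerator_backend.backward and PATCH 5/6 reverts it to result.closure_loss, presumably because the backward pass needs the graph-attached, accumulation-scaled loss rather than the detached copy.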