diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 20d654a5dba70..361d063baf106 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -934,73 +934,15 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         """
         wrap the forward step in a closure so second order methods work
         """
-        # ---------------------------
-        # FORWARD (TRAINING STEP + TRAIN STEP END)
-        # ---------------------------
-        with self.profiler.profile('model_forward'):
-            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
-            training_step_output = self.accelerator_backend.training_step(args)
-            training_step_output = self.call_hook('training_step_end', training_step_output)
-
-        # ----------------------------
-        # PROCESS THE RESULT
-        # ----------------------------
-        # format and reduce outputs accordingly
-        training_step_output_for_epoch_end = training_step_output
-        is_result_obj = isinstance(training_step_output, Result)
-
-        # track batch size for weighted average
-        if is_result_obj:
-            training_step_output.track_batch_size(len(split_batch))
-
-        # don't allow EvalResult in the training_step
-        if isinstance(training_step_output, EvalResult):
-            raise MisconfigurationException('training_step cannot return EvalResult, '
-                                            'use a dict or TrainResult instead')
-
-        # handle regular dicts
-        if not is_result_obj:
-            training_step_output = self.process_output(training_step_output, train=True)
-
-            training_step_output = AttributeDict(
-                batch_loss=training_step_output[0],
-                pbar_on_batch_end=training_step_output[1],
-                log_metrics=training_step_output[2],
-                callback_metrics=training_step_output[3],
-                hiddens=training_step_output[4],
-            )
-
-        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
-            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
-        elif is_result_obj:
-            training_step_output_for_epoch_end = copy(training_step_output)
-            training_step_output_for_epoch_end.detach()
-        else:
-            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
-
-        # accumulate loss
-        # (if accumulate_grad_batches = 1 no effect)
-        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
-        closure_loss = closure_loss / self.accumulate_grad_batches
-
-        # the loss will get scaled for amp. avoid any modifications to it
-        untouched_loss = closure_loss.detach().clone()
+        # lightning module hook
+        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)
 
         # backward pass
-        with self.profiler.profile('model_backward'):
-            closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)
+        self.train_loop.backward(result, optimizer, opt_idx)
 
         # hook
-        self.train_loop.on_after_backward(training_step_output, batch_idx, untouched_loss)
+        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
-        # result
-        result = AttributeDict(
-            loss=untouched_loss,
-            training_step_output=training_step_output,
-            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
-            hiddens=training_step_output.hiddens,
-        )
         return result
 
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 5164a4395c625..5591d135ee180 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -5,8 +5,12 @@
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.supporters import Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.core.step_result import Result
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities.parsing import AttributeDict
+from copy import copy
 
 
 class TrainLoop:
@@ -130,6 +134,11 @@ def get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.trainer.optimizers[opt_idx])]
 
+    def backward(self, result, optimizer, opt_idx):
+        # backward pass
+        with self.trainer.profiler.profile('model_backward'):
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+
     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)
 
@@ -143,3 +152,64 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
 
         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
+
+    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
+        with self.trainer.profiler.profile('model_forward'):
+            args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
+            training_step_output = self.trainer.accelerator_backend.training_step(args)
+            training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
+
+        # ----------------------------
+        # PROCESS THE RESULT
+        # ----------------------------
+        # format and reduce outputs accordingly
+        training_step_output_for_epoch_end = training_step_output
+        is_result_obj = isinstance(training_step_output, Result)
+
+        # track batch size for weighted average
+        if is_result_obj:
+            training_step_output.track_batch_size(len(split_batch))
+
+        # don't allow EvalResult in the training_step
+        if isinstance(training_step_output, EvalResult):
+            raise MisconfigurationException('training_step cannot return EvalResult, '
+                                            'use a dict or TrainResult instead')
+
+        # handle regular dicts
+        if not is_result_obj:
+            training_step_output = self.trainer.process_output(training_step_output, train=True)
+
+            training_step_output = AttributeDict(
+                batch_loss=training_step_output[0],
+                pbar_on_batch_end=training_step_output[1],
+                log_metrics=training_step_output[2],
+                callback_metrics=training_step_output[3],
+                hiddens=training_step_output[4],
+            )
+
+        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
+        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+        elif is_result_obj:
+            training_step_output_for_epoch_end = copy(training_step_output)
+            training_step_output_for_epoch_end.detach()
+        else:
+            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+
+        # accumulate loss
+        # (if accumulate_grad_batches = 1 no effect)
+        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
+        closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+
+        # the loss will get scaled for amp. avoid any modifications to it
+        untouched_loss = closure_loss.detach().clone()
+
+        # result
+        result = AttributeDict(
+            closure_loss=closure_loss,
+            loss=untouched_loss,
+            training_step_output=training_step_output,
+            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
+            hiddens=training_step_output.hiddens,
+        )
+        return result
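As a reviewer aid, here is a minimal, self-contained sketch of the control flow this diff introduces: `optimizer_closure` now delegates the forward pass to `TrainLoop.training_step`, which returns an `AttributeDict` carrying both the graph-attached `closure_loss` (already divided by `accumulate_grad_batches`) and a detached `loss` clone, and the backward pass goes through `TrainLoop.backward`. The `AttrDict`, `ToyTrainLoop`, and standalone `optimizer_closure` below are hypothetical stand-ins for illustration only, not code from the PR; amp scaling, hooks, and the Result/EvalResult handling are deliberately omitted.

```python
# Illustrative sketch only -- toy stand-ins, not the actual Lightning classes.
import torch


class AttrDict(dict):
    """Minimal AttributeDict-style stand-in: attribute access over a dict."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


class ToyTrainLoop:
    """Mirrors the training_step / backward split introduced by this diff."""

    def __init__(self, accumulate_grad_batches=1):
        self.accumulate_grad_batches = accumulate_grad_batches

    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        # stand-in for accelerator_backend.training_step + 'training_step_end' hook
        x, y = split_batch
        closure_loss = torch.nn.functional.mse_loss(x, y)

        # accumulate loss (no effect when accumulate_grad_batches == 1)
        closure_loss = closure_loss / self.accumulate_grad_batches

        # the loss may get scaled for amp; keep an untouched copy for hooks/logging
        untouched_loss = closure_loss.detach().clone()

        return AttrDict(
            closure_loss=closure_loss,   # graph-attached, consumed by backward()
            loss=untouched_loss,         # detached clone, passed to on_after_backward
            training_step_output=None,   # omitted in this toy sketch
            training_step_output_for_epoch_end=None,
            hiddens=hiddens,
        )

    def backward(self, result, optimizer, opt_idx):
        # stand-in for accelerator_backend.backward (amp scaling would live there)
        result.closure_loss.backward()


def optimizer_closure(train_loop, split_batch, batch_idx, opt_idx, optimizer, hiddens=None):
    # same three-step sequence as the refactored trainer method:
    # forward via TrainLoop.training_step, then TrainLoop.backward, then the hook
    result = train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)
    train_loop.backward(result, optimizer, opt_idx)
    # on_after_backward would receive (result.training_step_output, batch_idx, result.loss)
    return result


if __name__ == '__main__':
    w = torch.randn(3, requires_grad=True)
    batch = (w * 2.0, torch.zeros(3))
    opt = torch.optim.SGD([w], lr=0.1)
    out = optimizer_closure(ToyTrainLoop(), batch, batch_idx=0, opt_idx=0, optimizer=opt)
    print(out.loss, w.grad)
```

The design point the sketch highlights is the split inside the returned `AttributeDict`: `closure_loss` keeps the autograd graph and is what `TrainLoop.backward` hands to the accelerator backend (where amp scaling may modify it), while `loss` is a detached clone taken before any scaling, which is why `on_after_backward` and logging receive `result.loss` rather than `result.closure_loss`.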