diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 70d76f683bd07..635355f0bd197 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -739,10 +739,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # track metrics to log
         batch_log_metrics = []
 
+        # bookkeeping
         using_results_obj = False
+        self.hiddens = None
 
         # track all outputs across time and num of optimizers
-        batch_outputs = [[] for i in range(len(self.train_loop.get_optimizers_iterable()))]
+        batch_outputs = [[] for _ in range(len(self.train_loop.get_optimizers_iterable()))]
 
         if batch is None:
             return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@@ -757,16 +759,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         if response == -1:
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
-        splits = [batch]
-        if self.truncated_bptt_steps is not None:
-            model_ref = self.get_model()
-            with self.profiler.profile('tbptt_split_batch'):
-                splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
+        # lightning module hook
+        splits = self.train_loop.tbptt_split_batch(batch)
 
-        self.hiddens = None
         for split_idx, split_batch in enumerate(splits):
             self.split_idx = split_idx
 
+            # loop over optimizers
             for opt_idx, optimizer in self.train_loop.get_optimizers_iterable():
                 # make sure only the gradients of the current optimizer's parameters are calculated
                 # in the training step to prevent dangling gradients in multiple-optimizer setup.
@@ -780,7 +779,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # -------------------
                 # calculate loss (train step + train step end)
                 # -------------------
-                opt_closure_result = self.optimizer_closure(
+                opt_closure_result = self.training_step_and_backward(
                     split_batch,
                     batch_idx,
                     opt_idx,
@@ -808,13 +807,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # BACKWARD PASS
                 # ------------------------------
                 # gradient update with accumulated gradients
-                if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
-                        or (self.batch_idx + 1) == self.num_training_batches):
+                accumulation_done = (self.batch_idx + 1) % self.accumulate_grad_batches == 0
+                is_final_batch = (self.batch_idx + 1) == self.num_training_batches
+                if accumulation_done or is_final_batch:
                     # hook
                     grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)
 
+                    # wrap forward + backward pass in closure for 2nd order optimizers
+                    train_step_and_backward_closure = lambda: self.training_step_and_backward(
+                        split_batch, batch_idx, opt_idx, optimizer, self.hiddens,
+                    ).loss
+
                     # optimizer step
-                    self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+                    self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
 
                     # hook
                     self.train_loop.on_before_zero_grad(optimizer)
@@ -843,7 +848,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         )
         return result
 
-    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
+    def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
         """
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 89f87e137e9b7..82e14ca3d81e1 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -214,21 +214,11 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         )
         return result
 
-    def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
-        # calls .step(), .zero_grad()
-        # override function to modify this behavior
-
+    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
         with self.trainer.profiler.profile('optimizer_step'):
-            lambda_closure = lambda: self.trainer.optimizer_closure(
-                split_batch,
-                batch_idx,
-                opt_idx,
-                optimizer,
-                self.trainer.hiddens,
-            ).loss
-
             # optimizer step lightningModule hook
-            self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
+            self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx,
+                                                            train_step_and_backward_closure)
 
     def on_before_zero_grad(self, optimizer):
         model = self.trainer.get_model()
@@ -280,3 +270,11 @@ def process_hiddens(self, opt_closure_result):
         if isinstance(opt_closure_result.training_step_output, Result):
             opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
         return hiddens
+
+    def tbptt_split_batch(self, batch):
+        splits = [batch]
+        if self.trainer.truncated_bptt_steps is not None:
+            model_ref = self.trainer.get_model()
+            with self.trainer.profiler.profile('tbptt_split_batch'):
+                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+        return splits
diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/test_trainer_steps_dict_return.py
index db4d253b95422..868f4d7a1d6dc 100644
--- a/tests/trainer/test_trainer_steps_dict_return.py
+++ b/tests/trainer/test_trainer_steps_dict_return.py
@@ -47,7 +47,7 @@ def test_training_step_dict(tmpdir):
     assert pbar_metrics['pbar_acc2'] == 19.0
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
 
 
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py
index ed153a7b721d8..a120a1e7a3332 100644
--- a/tests/trainer/test_trainer_steps_result_return.py
+++ b/tests/trainer/test_trainer_steps_result_return.py
@@ -84,7 +84,7 @@ def test_training_step_result_log_step_only(tmpdir):
     assert f'step_log_acc2_b{batch_idx}' in train_step_out
 
    # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
 
 
@@ -158,7 +158,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
     assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
 
 
@@ -293,7 +293,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert 'epoch_step_epoch_log_acc2' in train_step_out
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
 
 
@@ -372,7 +372,7 @@ def test_training_step_epoch_end_result(tmpdir):
     assert 'epoch_step_epoch_log_acc2' in train_step_out
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
 
 
diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/test_trainer_steps_scalar_return.py
index 23addf3a7731e..baf09b22f5acf 100644
--- a/tests/trainer/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/test_trainer_steps_scalar_return.py
@@ -43,7 +43,7 @@ def test_training_step_scalar(tmpdir):
     assert train_step_out.item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'].item() == 171
 
 
@@ -80,7 +80,7 @@ def training_step_scalar_with_step_end(tmpdir):
     assert train_step_out.item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'].item() == 171
 
 
@@ -127,7 +127,7 @@ def test_full_training_loop_scalar(tmpdir):
     assert train_step_out.item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'].item() == 171
 
 
@@ -170,5 +170,5 @@ def test_train_step_epoch_end_scalar(tmpdir):
     assert train_step_out.item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+    opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
     assert opt_closure_result['loss'].item() == 171
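
Why the closure is needed: second-order optimizers such as torch.optim.LBFGS may re-evaluate the loss several times per call to .step(), so the optimizer must be handed a callable that reruns the forward and backward pass. The sketch below shows the same pattern in plain PyTorch with a toy model and random data (it is not Lightning's internals, only the public torch.optim API):

import torch

model = torch.nn.Linear(10, 1)                      # toy model for illustration
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

def train_step_and_backward_closure():
    # forward + backward, analogous to what training_step_and_backward wraps
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss                                     # LBFGS expects the loss back

# first-order optimizers can ignore the closure; LBFGS calls it repeatedly
optimizer.step(train_step_and_backward_closure)

The new tbptt_split_batch helper defers to the LightningModule's tbptt_split_batch hook whenever truncated_bptt_steps is set. A rough sketch of the kind of split such a hook produces, assuming a (batch, time, features) tensor (the hook is overridable, so the exact behaviour depends on the module):

import torch

truncated_bptt_steps = 5
batch = torch.randn(8, 20, 16)   # (batch, time, features), toy data

# chunk along the time dimension so backprop is truncated to short windows
splits = [
    batch[:, t:t + truncated_bptt_steps, :]
    for t in range(0, batch.size(1), truncated_bptt_steps)
]
assert len(splits) == 4          # 20 time steps / 5 steps per split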