diff --git a/pytorch_lightning/trainer/data_connector.py b/pytorch_lightning/trainer/data_connector.py
index 5b7556a4bc07a..226dd965eaebc 100644
--- a/pytorch_lightning/trainer/data_connector.py
+++ b/pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,25 @@ class DataConnector(object):
     def __init__(self, trainer):
         self.trainer = trainer
 
+    def get_profiled_train_dataloader(self, train_dataloader):
+        profiled_dl = self.trainer.profiler.profile_iterable(
+            enumerate(self._with_is_last(train_dataloader)),
+            "get_train_batch"
+        )
+        return profiled_dl
+
+    def _with_is_last(self, iterable):
+        """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+        See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+        it = iter(iterable)
+        last = next(it)
+        for val in it:
+            # yield last and has next
+            yield last, False
+            last = val
+        # yield last, no longer has next
+        yield last, True
+
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 0524dfe2bde5c..8a941f35afb06 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -183,6 +183,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.training_loop_temp import TrainLoop
+from pytorch_lightning.trainer.data_connector import DataConnector
 
 try:
     from apex import amp
@@ -264,6 +265,7 @@ class TrainerTrainLoopMixin(ABC):
     accelerator_backend: ...
     val_dataloaders: ...
     train_loop: TrainLoop
+    data_connector: DataConnector
 
     # Callback system
     callbacks: List[Callback]
@@ -443,10 +445,10 @@ def run_training_epoch(self):
         # track epoch output
         epoch_output = [[] for _ in range(self.train_loop.num_optimizers)]
 
-        # run epoch
-        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
-            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
-        ):
+        # enable profiling for the dataloader
+        train_dataloader = self.data_connector.get_profiled_train_dataloader(train_dataloader)
+        dataloader_idx = 0
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
                 break
@@ -457,7 +459,7 @@ def run_training_epoch(self):
             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
             # ------------------------------------
-            batch_output = self.run_training_batch(batch, batch_idx)
+            batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
 
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
@@ -467,12 +469,8 @@ def run_training_epoch(self):
                 self.train_loop.checkpoint_accumulator
             )
 
-            # track the outputs to reduce at the end of the epoch
-            for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
-                # with 1 step (no tbptt) don't use a sequence at epoch end
-                if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
-                    opt_outputs = opt_outputs[0]
-                epoch_output[opt_idx].append(opt_outputs)
+            # hook
+            self.train_loop.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # when returning -1 from train_step, we end epoch early
             self.should_stop = batch_output.signal == -1
@@ -748,7 +746,7 @@ def should_check_val(self, batch_idx, is_last_batch):
 
         return should_check_val
 
-    def run_training_batch(self, batch, batch_idx):
+    def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # track grad norms
         grad_norm_dic = {}
 
@@ -767,7 +765,6 @@ def run_training_batch(self, batch, batch_idx):
             return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
 
         # hook
-        dataloader_idx = 0
         response = self.call_hook('on_batch_start')
         if response == -1:
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
@@ -859,12 +856,6 @@ def run_training_batch(self, batch, batch_idx):
             # reset for next set of accumulated grads
             self.batch_loss_value.reset()
 
-        # hook
-        self.call_hook('on_batch_end')
-
-        # hook
-        self.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)
-
         # collapse all metrics into one dict
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
 
@@ -1186,16 +1177,3 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
                         scheduler_idx, old_lr, new_lr
                     )
-
-
-def _with_is_last(iterable):
-    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
-    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
-    it = iter(iterable)
-    last = next(it)
-    for val in it:
-        # yield last and has next
-        yield last, False
-        last = val
-    # yield last, no longer has next
-    yield last, True
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 6df9028d85b7a..1695b53b3fa80 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -1,5 +1,6 @@
 from pytorch_lightning.trainer.supporters import Accumulator
 import numpy as np
+from pytorch_lightning.core.step_result import Result
 
 
 class TrainLoop:
@@ -27,6 +28,23 @@ def on_train_epoch_start(self):
         self.early_stopping_accumulator = Accumulator()
         self.checkpoint_accumulator = Accumulator()
 
+    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+        # figure out what to track for epoch end
+        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+
+        # hook
+        self.trainer.call_hook('on_batch_end')
+        self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)
+
+    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+        # track the outputs to reduce at the end of the epoch
+        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+            # with 1 step (no tbptt) don't use a sequence at epoch end
+            if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
+                opt_outputs = opt_outputs[0]
+            epoch_output[opt_idx].append(opt_outputs)
+
+
     def get_optimizers_iterable(self):
         """
         Generates an iterable with (idx, optimizer) for each optimizer.
diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/test_trainer_steps_dict_return.py
index 1fca3159c96c5..db4d253b95422 100644
--- a/tests/trainer/test_trainer_steps_dict_return.py
+++ b/tests/trainer/test_trainer_steps_dict_return.py
@@ -30,7 +30,7 @@ def test_training_step_dict(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert out.batch_log_metrics['log_acc1'] == 12.0
     assert out.batch_log_metrics['log_acc2'] == 7.0
@@ -76,7 +76,7 @@ def training_step_with_step_end(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert out.batch_log_metrics['log_acc1'] == 14.0
     assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -117,7 +117,7 @@ def test_full_training_loop_dict(tmpdir):
 
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert out.batch_log_metrics['log_acc1'] == 14.0
     assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -204,7 +204,7 @@ def test_train_step_epoch_end(tmpdir):
 
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert out.batch_log_metrics['log_acc1'] == 12.0
     assert out.batch_log_metrics['log_acc2'] == 7.0
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py
index 291e45f9dfe54..ed153a7b721d8 100644
--- a/tests/trainer/test_trainer_steps_result_return.py
+++ b/tests/trainer/test_trainer_steps_result_return.py
@@ -69,7 +69,7 @@ def test_training_step_result_log_step_only(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
     assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
@@ -144,7 +144,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 0
 
@@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 2
 
@@ -356,7 +356,7 @@ def test_training_step_epoch_end_result(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 2
 
diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/test_trainer_steps_scalar_return.py
index 40c716ac477df..23addf3a7731e 100644
--- a/tests/trainer/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/test_trainer_steps_scalar_return.py
@@ -31,7 +31,7 @@ def test_training_step_scalar(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
     assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -68,7 +68,7 @@ def training_step_scalar_with_step_end(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
     assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -115,7 +115,7 @@ def test_full_training_loop_scalar(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
     assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -158,7 +158,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         break
 
-    out = trainer.run_training_batch(batch, batch_idx)
+    out = trainer.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
     assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
     assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
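
Note (reviewer sketch, not part of the patch): `_with_is_last` moves into `DataConnector` unchanged, and `get_profiled_train_dataloader` still wraps it in `enumerate()`, so the training loop keeps unpacking `batch_idx, (batch, is_last_batch)`. A minimal standalone illustration of the helper's behaviour:

```python
def _with_is_last(iterable):
    """Yield (item, is_last) pairs, flagging the final element of the iterable."""
    it = iter(iterable)
    last = next(it)
    for val in it:
        # there is still a next item, so `last` is not the last one
        yield last, False
        last = val
    # the iterator is exhausted, so this really is the last item
    yield last, True


assert list(_with_is_last([10, 20, 30])) == [(10, False), (20, False), (30, True)]
# enumerate() around it reproduces what the training loop consumes
assert list(enumerate(_with_is_last("ab"))) == [(0, ("a", False)), (1, ("b", True))]
```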
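
The batch-end hooks are no longer fired inside `run_training_batch`; `run_training_epoch` now calls `TrainLoop.on_train_batch_end`, which tracks the epoch-end outputs and then fires `on_batch_end` followed by `on_train_batch_end`. A rough sketch of that call order using stand-in classes (not Lightning's real `Trainer`/`TrainLoop`):

```python
class FakeTrainer:
    """Stand-in that only records which hooks were called."""

    def __init__(self):
        self.calls = []

    def call_hook(self, name, *args):
        self.calls.append(name)


class FakeTrainLoop:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
        # epoch-end output tracking elided; then fire both hooks, in this order
        self.trainer.call_hook('on_batch_end')
        self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)


trainer = FakeTrainer()
loop = FakeTrainLoop(trainer)
loop.on_train_batch_end(epoch_output=[], epoch_end_outputs=[], batch=None, batch_idx=0, dataloader_idx=0)
assert trainer.calls == ['on_batch_end', 'on_train_batch_end']
```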
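
The epoch-end bookkeeping that previously lived inline in `run_training_epoch` now sits in `TrainLoop.track_epoch_end_reduce_metrics`. A minimal sketch of its unwrapping rule (single step, no tbptt), using a stand-in for `Result` so it runs outside Lightning:

```python
class Result(dict):
    """Stand-in for pytorch_lightning.core.step_result.Result."""


def track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs):
    # track the outputs to reduce at the end of the epoch
    for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
        # with 1 step (no tbptt) don't use a sequence at epoch end
        if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
            opt_outputs = opt_outputs[0]
        epoch_output[opt_idx].append(opt_outputs)


epoch_output = [[]]                    # one accumulator list per optimizer
epoch_end_outputs = [[{'loss': 0.5}]]  # one optimizer, a single truncated-bptt step
track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
assert epoch_output == [[{'loss': 0.5}]]  # the single-element inner list was unwrapped
```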