From 1954d7c87a670ab23e289ed1ef3500140c1d89c2 Mon Sep 17 00:00:00 2001
From: Nathan Raw
Date: Mon, 5 Oct 2020 16:04:02 -0600
Subject: [PATCH] Write predictions in LightningModule instead of EvalResult
 (#3882)

* :sparkles: add self.write_prediction

* :sparkles: add self.write_prediction_dict to lightning module
---
 pytorch_lightning/core/lightning.py |  7 ++++++
 tests/base/model_test_steps.py      | 33 ++++++++++++++++-------------
 tests/core/test_results.py          |  3 +++
 3 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 77704b1ff12f0..6a540bba7022e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -304,6 +304,13 @@ def log_dict(
             tbptt_reduce_fx=tbptt_reduce_fx,
         )
 
+    def write_prediction(self, name, value, filename='predictions.pt'):
+        self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename)
+
+    def write_prediction_dict(self, predictions_dict, filename='predictions.pt'):
+        for k, v in predictions_dict.items():
+            self.write_prediction(k, v, filename)
+
     def __auto_choose_log_on_step(self, on_step):
         if on_step is None:
             if self._current_fx_name in {'training_step', 'training_step_end'}:
diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py
index 92b1d68e675a1..da49327a75ce2 100644
--- a/tests/base/model_test_steps.py
+++ b/tests/base/model_test_steps.py
@@ -151,38 +151,41 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
 
         # Base
         if option == 0:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('preds', labels_hat, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('preds', labels_hat, prediction_file)
 
         # Check mismatching tensor len
         elif option == 1:
-            result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
-            result.write('preds', labels_hat, prediction_file)
+            self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
+            self.write_prediction('preds', labels_hat, prediction_file)
 
         # write multi-dimension
         elif option == 2:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('preds', labels_hat, prediction_file)
-            result.write('x', x, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('preds', labels_hat, prediction_file)
+            self.write_prediction('x', x, prediction_file)
 
         # write str list
         elif option == 3:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('vals', lst_of_str, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('vals', lst_of_str, prediction_file)
 
         # write int list
         elif option == 4:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('vals', lst_of_int, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('vals', lst_of_int, prediction_file)
 
         # write nested list
         elif option == 5:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('vals', lst_of_lst, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('vals', lst_of_lst, prediction_file)
 
         # write dict list
         elif option == 6:
-            result.write('idxs', lazy_ids, prediction_file)
-            result.write('vals', lst_of_dict, prediction_file)
+            self.write_prediction('idxs', lazy_ids, prediction_file)
+            self.write_prediction('vals', lst_of_dict, prediction_file)
+
+        elif option == 7:
+            self.write_prediction_dict({'idxs': lazy_ids, 'preds': labels_hat}, prediction_file)
 
         return result
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 48142a0e95a61..5aede2c42ca32 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -70,6 +70,9 @@ def test_result_reduce_ddp(result_cls):
     pytest.param(
         6, False, 0, id='dict_list_predictions'
     ),
+    pytest.param(
+        7, True, 0, id='write_dict_predictions'
+    ),
     pytest.param(
         0, True, 1, id='full_loop_single_gpu',
         marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
     )
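Usage note (not part of the patch): the sketch below shows how the new write_prediction and
write_prediction_dict hooks might be called from a test_step. The LitClassifier module, its layer
sizes, the random TensorDataset, and the sample_ids field are assumptions made for illustration;
only the two hooks come from this change, and both delegate to self.trainer.evaluation_loop, so
they are usable only while a Trainer is driving the module.

# Illustrative sketch; model, data, and 'sample_ids' are hypothetical, not from the patch.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    """Minimal module used only to demonstrate the prediction-writing hooks."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)

    def forward(self, x):
        return self.layer(x)

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = torch.argmax(self(x), dim=1)
        # Hypothetical per-sample ids derived from the batch index.
        sample_ids = torch.arange(batch_idx * x.size(0), batch_idx * x.size(0) + x.size(0))

        # Write one named column per call ...
        self.write_prediction('idxs', sample_ids, filename='predictions.pt')
        self.write_prediction('preds', preds, filename='predictions.pt')

        # ... or, equivalently, several columns in a single call:
        # self.write_prediction_dict({'idxs': sample_ids, 'preds': preds}, filename='predictions.pt')


if __name__ == '__main__':
    data = TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,)))
    trainer = pl.Trainer(max_epochs=1)
    trainer.test(LitClassifier(), DataLoader(data, batch_size=16))

When the test loop finishes, the trainer's evaluation loop flushes whatever was added to the given
filename, so the collected ids and predictions end up in predictions.pt.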