Skip to content

Commit

Permalink
Write predictions in LightningModule instead of EvalResult (#3882)
Browse files Browse the repository at this point in the history
* ✨ add self.write_prediction

* ✨ add self.write_prediction_dict to lightning module
  • Loading branch information
nateraw authored Oct 5, 2020
1 parent cea5f1f commit 1954d7c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
7 changes: 7 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ def log_dict(
tbptt_reduce_fx=tbptt_reduce_fx,
)

def write_prediction(self, name, value, filename='predictions.pt'):
self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename)

def write_prediction_dict(self, predictions_dict, filename='predictions.pt'):
for k, v in predictions_dict.items():
self.write_prediction(k, v, filename)

def __auto_choose_log_on_step(self, on_step):
if on_step is None:
if self._current_fx_name in {'training_step', 'training_step_end'}:
Expand Down
33 changes: 18 additions & 15 deletions tests/base/model_test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,38 +151,41 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):

# Base
if option == 0:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)

# Check mismatching tensor len
elif option == 1:
result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
result.write('preds', labels_hat, prediction_file)
self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)

# write multi-dimension
elif option == 2:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
result.write('x', x, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('preds', labels_hat, prediction_file)
self.write_prediction('x', x, prediction_file)

# write str list
elif option == 3:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_str, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_str, prediction_file)

# write int list
elif option == 4:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_int, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_int, prediction_file)

# write nested list
elif option == 5:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_lst, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_lst, prediction_file)

# write dict list
elif option == 6:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_dict, prediction_file)
self.write_prediction('idxs', lazy_ids, prediction_file)
self.write_prediction('vals', lst_of_dict, prediction_file)

elif option == 7:
self.write_prediction_dict({'idxs': lazy_ids, 'preds': labels_hat}, prediction_file)

return result
3 changes: 3 additions & 0 deletions tests/core/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def test_result_reduce_ddp(result_cls):
pytest.param(
6, False, 0, id='dict_list_predictions'
),
pytest.param(
7, True, 0, id='write_dict_predictions'
),
pytest.param(
0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
)
Expand Down

0 comments on commit 1954d7c

Please sign in to comment.