Trainer.test should return only test metrics (#5214)
* resolve bug

* merge tests

(cherry picked from commit 9ebbfec)
tchaton authored and Borda committed Jan 6, 2021
1 parent 21fe640 · commit 5ad13dc
Showing 3 changed files with 28 additions and 40 deletions.
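In short: trainer.test(...) previously returned validation metrics mixed in with test metrics, because the evaluation metrics store was also updated during validation runs. The diffs below gate every update of evaluation_callback_metrics behind self.trainer.testing. Here is a minimal sketch of the resulting behavior, assuming the Lightning 1.1-era API; the module and dataset names below are made up for illustration and are not part of the commit:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """64 random feature vectors of size 32."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class SimpleModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def compute_loss(self, batch):
        return self.layer(batch).sum()

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log('val_loss', self.compute_loss(batch))

    def test_step(self, batch, batch_idx):
        self.log('test_loss', self.compute_loss(batch))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = SimpleModel()
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    checkpoint_callback=False,
    logger=False,
    weights_summary=None,
    progress_bar_refresh_rate=0,
)
trainer.fit(model, DataLoader(RandomDataset()), DataLoader(RandomDataset()))
results = trainer.test(model, DataLoader(RandomDataset()))

# callback_metrics still aggregates metrics logged in every loop ...
assert 'train_loss' in trainer.callback_metrics
assert 'val_loss' in trainer.callback_metrics
# ... but the list returned by Trainer.test (one dict per test dataloader)
# now holds only metrics logged during the test loop
assert 'test_loss' in results[0]
assert 'val_loss' not in results[0]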
(first changed file)
@@ -399,7 +399,7 @@ def update_logger_connector(self) -> None:
         callback_metrics.update(epoch_log_metrics)
         callback_metrics.update(forked_metrics)

-        if not is_train:
+        if not is_train and self.trainer.testing:
             logger_connector.evaluation_callback_metrics.update(callback_metrics)

         # update callback_metrics
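This one-line guard is the core of the fix: evaluation_callback_metrics is the store from which Trainer.test assembles its return value, and it was previously filled during validation runs as well, which is how validation metrics leaked into the result of trainer.test(...). With the added self.trainer.testing check, the three loops behave as follows:

    loop         is_train   trainer.testing   evaluation_callback_metrics updated?
    training     True       False             no
    validation   False      False             no (previously yes: the bug)
    testing      False      True              yes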

(second changed file)
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from copy import deepcopy
+import os
 from pprint import pprint
 from typing import Iterable, Union

@@ -276,10 +276,13 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
                     self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
-                    self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
+                    if self.trainer.testing:
+                        self.trainer.logger_connector.evaluation_callback_metrics.update(
+                            eval_result.callback_metrics)
             else:
                 self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
-                self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
+                if self.trainer.testing:
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
         else:
             flat = {}
             if isinstance(eval_results, list):
@@ -295,7 +298,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
                         flat['checkpoint_on'] = flat['val_loss']
                         flat['early_stop_on'] = flat['val_loss']
                     self.trainer.logger_connector.callback_metrics.update(flat)
-                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
+                    if self.trainer.testing:
+                        self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
             else:
                 # with a scalar return, auto set it to "val_loss" for callbacks
                 if isinstance(eval_results, torch.Tensor):
@@ -308,7 +312,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
                     flat['checkpoint_on'] = flat['val_loss']
                     flat['early_stop_on'] = flat['val_loss']
                 self.trainer.logger_connector.callback_metrics.update(flat)
-                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
+                if self.trainer.testing:
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

     def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
         # eval loop returns all metrics
@@ -325,7 +330,8 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
             callback_metrics.update(log_metrics)
             callback_metrics.update(prog_bar_metrics)
             self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
+            if self.trainer.testing:
+                self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

             if len(dataloader_result_metrics) > 0:
                 self.eval_loop_results.append(dataloader_result_metrics)
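The same two-line guard is repeated at four update sites in _track_callback_metrics and once more in __process_eval_epoch_end_results_and_log_legacy_update. A hedged refactor sketch, not part of this commit and using a hypothetical helper name, that would state the rule once:

from typing import Dict

import torch


def _update_callback_metrics(trainer, metrics: Dict[str, torch.Tensor]) -> None:
    # callback_metrics keeps accumulating in every loop; the evaluation copy,
    # which Trainer.test reads its return value from, is test-only after #5214
    trainer.logger_connector.callback_metrics.update(metrics)
    if trainer.testing:
        trainer.logger_connector.evaluation_callback_metrics.update(metrics)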

(third changed file) tests/trainer/logging_tests/test_eval_loop_logging_1_0.py (15 additions, 33 deletions)
@@ -24,7 +24,7 @@
 import pytest
 import torch

-from pytorch_lightning import Trainer, callbacks, seed_everything
+from pytorch_lightning import callbacks, seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -812,7 +812,7 @@ def validation_step(self, batch, batch_idx):
         def test_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
-            self.log('fake_test_acc', loss)
+            self.log('test_loss', loss)
             return {"y": loss}

     model = ExtendedModel()
@@ -824,7 +824,7 @@ def test_step(self, batch, batch_idx):
         logger=TensorBoardLogger(tmpdir),
         limit_train_batches=2,
         limit_val_batches=2,
-        limit_test_batches=0,
+        limit_test_batches=2,
         max_epochs=2,
         progress_bar_refresh_rate=1,
     )
@@ -876,33 +876,15 @@ def get_metrics_at_idx(idx):
     expected = torch.stack(model.val_losses[4:]).mean()
     assert get_metrics_at_idx(6)["valid_loss_1"] == expected

-
-def test_progress_bar_dict_contains_values_on_test_epoch_end(tmpdir):
-    class TestModel(BoringModel):
-        def test_step(self, *args):
-            self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
-
-        def test_epoch_end(self, *_):
-            self.epoch_end_called = True
-            self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True,
-                     on_epoch=True, sync_dist=True, sync_dist_op='sum')
-
-        def on_test_epoch_end(self, *_):
-            self.on_test_epoch_end_called = True
-            assert self.trainer.progress_bar_dict["foo"] == self.current_epoch
-            assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=2,
-        limit_train_batches=1,
-        num_sanity_val_steps=2,
-        checkpoint_callback=False,
-        logger=False,
-        weights_summary=None,
-        progress_bar_refresh_rate=0,
-    )
-    model = TestModel()
-    trainer.test(model)
-    assert model.epoch_end_called
-    assert model.on_test_epoch_end_called
+    results = trainer.test(model)
+    expected_callback_metrics = {
+        'train_loss',
+        'valid_loss_0_epoch',
+        'valid_loss_0',
+        'debug_epoch',
+        'valid_loss_1',
+        'test_loss',
+        'val_loss'
+    }
+    assert set(trainer.callback_metrics) == expected_callback_metrics
+    assert set(results[0]) == {'test_loss', 'debug_epoch'}
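The merged test now pins down the contract from both ends: trainer.callback_metrics still aggregates the metrics logged across the train, validation, and test loops (plus Lightning's internal debug_epoch entry), while each per-dataloader dict in the list returned by trainer.test is limited to the test-loop metrics.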
