Skip to content

Commit 86793ef

Browse files
committed
Address comments
1 parent 108fbd3 commit 86793ef

File tree

3 files changed

+24
-26
lines changed

3 files changed

+24
-26
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
147147
- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
148148

149149

150-
- Removed legacy code to include `step` dictionary returns in `callback_metrics` ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))
150+
- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))
151151

152152

153153
- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -340,20 +340,6 @@ def _track_callback_metrics(self, eval_results):
340340
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
341341
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
342342

343-
def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics):
344-
# eval loop returns all metrics
345-
dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
346-
347-
# add metrics to prog bar
348-
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
349-
350-
# log metrics
351-
if len(log_metrics) > 0:
352-
self.trainer.logger_connector.log_metrics(log_metrics, {})
353-
354-
if len(dataloader_result_metrics) > 0:
355-
self.eval_loop_results.append(dataloader_result_metrics)
356-
357343
def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
358344
if self.trainer.sanity_checking:
359345
return
@@ -364,17 +350,21 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
364350
if not isinstance(eval_results, list):
365351
eval_results = [eval_results]
366352

367-
num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
368-
prog_bar_metrics, log_metrics = {}, {}
369-
370353
for result_idx, result in enumerate(eval_results):
371354
_, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)
372355

373-
if num_loaders > 1:
374-
self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics)
356+
# eval loop returns all metrics
357+
dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
358+
359+
# add metrics to prog bar
360+
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
361+
362+
# log metrics
363+
if len(log_metrics) > 0:
364+
self.trainer.logger_connector.log_metrics(log_metrics, {})
375365

376-
if num_loaders == 1:
377-
self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics)
366+
if len(dataloader_result_metrics) > 0:
367+
self.eval_loop_results.append(dataloader_result_metrics)
378368

379369
def on_train_epoch_end(self):
380370
# inform cached logger connector epoch finished

tests/trainer/logging_/test_logger_connector.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,22 @@ class TestModel(BoringModel):
487487
def validation_step(self, batch, *args, **kwargs):
488488
return {"val": torch.tensor([0, 1])}
489489

490+
def validation_epoch_end(self, outputs):
491+
# ensure validation step returns still appear here
492+
assert len(outputs) == 2
493+
assert all(list(d) == ["val"] for d in outputs) # check keys
494+
assert all(torch.equal(d["val"], torch.tensor([0, 1])) for d in outputs) # check values
495+
490496
def test_step(self, batch, *args, **kwargs):
491497
return {"test": torch.tensor([0, 1])}
492498

493-
model = TestModel()
494-
model.validation_epoch_end = None
495-
model.test_epoch_end = None
499+
def test_epoch_end(self, outputs):
500+
assert len(outputs) == 2
501+
assert all(list(d) == ["test"] for d in outputs) # check keys
502+
assert all(torch.equal(d["test"], torch.tensor([0, 1])) for d in outputs) # check values
496503

497-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, progress_bar_refresh_rate=0)
504+
model = TestModel()
505+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, progress_bar_refresh_rate=0)
498506
trainer.fit(model)
499507
trainer.validate(model)
500508
trainer.test(model)

0 commit comments

Comments
 (0)