
Commit bc2c2db

Do not override the logged epoch in logged_metrics (#7982)
1 parent 2134216 commit bc2c2db

File tree

3 files changed, +38 -23 lines changed


CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -164,6 +164,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


 - Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
+
+
 - MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


@@ -255,6 +257,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


+- Do not override the existing `epoch` value in `logged_metrics` when already logged by the user ([#7982](https://github.com/PyTorchLightning/pytorch-lightning/issues/7982))
+
+
 - Support manual optimization with DeepSpeed ([#7970](https://github.com/PyTorchLightning/pytorch-lightning/pull/7970))


pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 17 additions & 23 deletions

@@ -91,6 +91,9 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -
             step: Step for which metrics should be logged. Default value is `self.global_step` during training or
                 the total validation / test log step count during validation and testing.
         """
+        if self.trainer.logger is None or not metrics:
+            return
+
         # add gpu memory
         if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
             mem_map = memory.get_memory_profile(self.log_gpu_memory)
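
This early return is the hinge for the caller-side cleanups further down: once `log_metrics` itself bails out on a missing logger or empty metrics, every call site can drop its own `if metrics:` wrapper. A minimal sketch of the guard-clause pattern, using a generic `logger` stand-in rather than the real trainer internals:

```python
from typing import Any, Dict, Optional


def log_metrics(logger: Optional[Any], metrics: Dict[str, float]) -> None:
    # Guard clause: nothing to log, or nowhere to log it, so return early.
    # Callers can now invoke this unconditionally.
    if logger is None or not metrics:
        return
    print(f"logging {metrics}")  # stand-in for the real aggregation + save path


log_metrics(None, {"loss": 0.1})      # no-op: no logger configured
log_metrics(object(), {})             # no-op: empty metrics dict
log_metrics(object(), {"loss": 0.1})  # reaches the logging path
```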
@@ -99,21 +102,19 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -
         # turn all tensors to scalars
         scalar_metrics = metrics_to_scalars(metrics)

-        if "step" in scalar_metrics and step is None:
-            step = scalar_metrics.pop("step")
-
-        elif step is None:
-            # added metrics by Lightning for convenience
-            scalar_metrics['epoch'] = self.trainer.current_epoch
+        if step is None:
+            step = scalar_metrics.pop("step", None)
+        if step is None:
+            # added metrics for convenience
+            scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
             step = self.trainer.global_step

         # log actual metrics
-        if self.trainer.logger is not None:
-            if self.trainer.is_global_zero:
-                self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
-            self.trainer.logger.save()
+        if self.trainer.is_global_zero:
+            self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
+            self.trainer.logger.save()

-            self._logged_metrics.update(scalar_metrics)
+        self._logged_metrics.update(scalar_metrics)

     """
     Evaluation metric updates
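
The behavioral fix in this hunk comes down to `dict.setdefault`: plain assignment always overwrites, while `setdefault` writes the convenience `epoch` value only when the key is absent. A self-contained sketch of the difference, using plain dicts in place of trainer state:

```python
def attach_epoch_old(scalar_metrics: dict, current_epoch: int) -> dict:
    # Old behavior: unconditionally overwrites a user-logged "epoch".
    scalar_metrics["epoch"] = current_epoch
    return scalar_metrics


def attach_epoch_new(scalar_metrics: dict, current_epoch: int) -> dict:
    # New behavior: only fills in "epoch" when the user did not log one.
    scalar_metrics.setdefault("epoch", current_epoch)
    return scalar_metrics


user_logged = {"loss": 0.5, "epoch": -1}  # as if the user called self.log('epoch', -1)
assert attach_epoch_old(dict(user_logged), current_epoch=3)["epoch"] == 3   # clobbered
assert attach_epoch_new(dict(user_logged), current_epoch=3)["epoch"] == -1  # preserved
```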
@@ -149,9 +150,7 @@ def update_eval_step_metrics(self) -> None:

         # logs user requested information to logger
         assert not self._epoch_end_reached
-        metrics = self.metrics[MetricSource.LOG]
-        if metrics:
-            self.log_metrics(metrics, step=self._eval_log_step)
+        self.log_metrics(self.metrics[MetricSource.LOG], step=self._eval_log_step)

         # increment the step even if nothing was logged
         self._increment_eval_log_step()
@@ -179,9 +178,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:

         if not self.trainer.sanity_checking:
             # log all the metrics as a single dict
-            log_metrics = metrics[MetricSource.LOG]
-            if log_metrics:
-                self.log_metrics(log_metrics)
+            self.log_metrics(metrics[MetricSource.LOG])

         self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK])

@@ -219,16 +216,13 @@ def update_train_step_metrics(self) -> None:

         # when metrics should be logged
         assert not self._epoch_end_reached
-        metrics = self.metrics[MetricSource.LOG]
-        if self.should_update_logs or self.trainer.fast_dev_run is True and metrics:
-            self.log_metrics(metrics)
+        if self.should_update_logs or self.trainer.fast_dev_run:
+            self.log_metrics(self.metrics[MetricSource.LOG])

     def update_train_epoch_metrics(self) -> None:
         # add the metrics to the loggers
         assert self._epoch_end_reached
-        metrics = self.metrics[MetricSource.LOG]
-        if metrics:
-            self.log_metrics(metrics)
+        self.log_metrics(self.metrics[MetricSource.LOG])

         # reset result collection for next epoch
         self.trainer._results.reset(metrics=True)
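
One detail worth calling out in the removed `update_train_step_metrics` condition: since `and` binds tighter than `or`, the old expression grouped as `should_update_logs or (fast_dev_run is True and metrics)`, so an empty `metrics` dict could still slip through whenever `should_update_logs` was true. Moving the emptiness check into `log_metrics` makes the caller-side condition both simpler and correct. A tiny demonstration of the grouping:

```python
# `and` binds tighter than `or`, so this groups as a or (b and c).
should_update_logs, fast_dev_run, metrics = True, False, {}
old_condition = should_update_logs or fast_dev_run is True and metrics
assert old_condition is True  # empty metrics would still have reached the logging call
```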

tests/trainer/logging_/test_logger_connector.py

Lines changed: 16 additions & 0 deletions

@@ -492,3 +492,19 @@ def test_result_collection_on_tensor_with_mean_reduction():
         'loss_on_step_on_epoch_prog_bar_logger': mean,
         'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
     }
+
+
+def test_logged_metrics_has_logged_epoch_value(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            self.log('epoch', -batch_idx, logger=True)
+            return super().training_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)
+
+    # should not get overridden if logged manually
+    assert trainer.logged_metrics == {'epoch': -1}
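
For context on the assertion: `fast_dev_run=2` runs `training_step` for `batch_idx` 0 and 1, so the last user-logged value under the `epoch` key is `-1`. Before this fix, the logger connector would have overwritten it with `trainer.current_epoch` (0 here); with `setdefault`, the user's value survives in `trainer.logged_metrics`.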
