ref: enable self.log from val step (#3701)
* .log in eval

* ref

* ref: enable self.log in val step
williamFalcon authored Sep 28, 2020
1 parent 2ecaa2a commit cdd7266
Showing 5 changed files with 91 additions and 20 deletions.
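
For context, the user-facing effect of this change is that self.log can now be called from validation_step just like from training_step. A minimal sketch of the pattern (the module, metric names, and layer sizes below are illustrative, not part of this commit; the test added at the bottom of the diff exercises the same call):

    import torch
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            self.log('train_loss', loss, on_step=True, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            # enabled by this commit: logging from the val step,
            # with optional per-step and per-epoch aggregation
            self.log('val_loss', loss, on_step=True, on_epoch=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)
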
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -87,7 +87,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
             self.trainer.logger.save()

         # track the logged metrics
-        self.logged_metrics = scalar_metrics
+        self.logged_metrics.update(scalar_metrics)
         self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)

     def add_progress_bar_metrics(self, metrics):
@@ -191,9 +191,8 @@ def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):

         return eval_loop_results

-    def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
-        self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
-                                         early_stopping_accumulator, num_optimizers)
+    def on_train_epoch_end(self, epoch_output):
+        pass

     def log_train_epoch_end_metrics(self,
                                     epoch_output,
@@ -413,7 +412,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

         return gathered_epoch_outputs

-    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
+    def log_train_step_metrics(self, batch_idx, batch_output):
         # when metrics should be logged
         should_log_metrics = (batch_idx + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
         if should_log_metrics or self.trainer.fast_dev_run:
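
A note on the logged_metrics change in the first hunk: switching from plain assignment to update() lets metrics logged during evaluation merge into the same trainer.logged_metrics dict as the training-step metrics instead of replacing them, which is what the new test at the bottom of this commit asserts. A tiny illustration (not part of the diff; keys borrowed from that test):

    # before: the second write overwrites the first, so only val metrics survive
    logged_metrics = {'train_step_acc': 0.5}
    logged_metrics = {'val_step_acc/epoch_0': 0.4}

    # after: both coexist
    logged_metrics = {'train_step_acc': 0.5}
    logged_metrics.update({'val_step_acc/epoch_0': 0.4})
    # -> {'train_step_acc': 0.5, 'val_step_acc/epoch_0': 0.4}
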
34 changes: 23 additions & 11 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -245,6 +245,11 @@ def __auto_reduce_result_objs(self, outputs):
         return eval_results

     def on_evaluation_batch_start(self, *args, **kwargs):
+        # reset the result of the PL module
+        model = self.trainer.get_model()
+        model._results = Result()
+        model._current_fx_name = 'evaluation_step'
+
         if self.testing:
             self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
         else:
@@ -273,21 +278,28 @@ def on_evaluation_epoch_end(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

-    def log_step_metrics(self, output, batch_idx):
+    def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.running_sanity_check:
             return

+        results = self.trainer.get_model()._results
+        self.__log_result_step_metrics(results, batch_idx)
+
+        # TODO: deprecate at 1.0
         if isinstance(output, EvalResult):
-            step_log_metrics = output.batch_log_metrics
-            step_pbar_metrics = output.batch_pbar_metrics
+            self.__log_result_step_metrics(output, batch_idx)

-        if len(step_log_metrics) > 0:
-            # make the metrics appear as a different line in the same graph
-            metrics_by_epoch = {}
-            for k, v in step_log_metrics.items():
-                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
+    def __log_result_step_metrics(self, output, batch_idx):
+        step_log_metrics = output.batch_log_metrics
+        step_pbar_metrics = output.batch_pbar_metrics

-            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+        if len(step_log_metrics) > 0:
+            # make the metrics appear as a different line in the same graph
+            metrics_by_epoch = {}
+            for k, v in step_log_metrics.items():
+                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

-        if len(step_pbar_metrics) > 0:
-            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
+            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+
+        if len(step_pbar_metrics) > 0:
+            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
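
How the two hunks above fit together: on_evaluation_batch_start attaches a fresh Result() to the LightningModule before every val/test batch, so self.log calls made inside validation_step have a per-batch container to write into; log_evaluation_step_metrics then reads that container back and forwards its step metrics to the loggers and progress bar. A rough sketch of the per-batch flow (the wrapper function below is hypothetical glue for illustration, not code from this commit):

    from pytorch_lightning.core.step_result import Result

    def run_one_eval_batch(trainer, model, batch, batch_idx):
        # what on_evaluation_batch_start does: give the module a clean per-batch Result
        model._results = Result()
        model._current_fx_name = 'evaluation_step'

        # user code: self.log(...) inside validation_step writes into model._results
        output = model.validation_step(batch, batch_idx)

        # what the trainer does afterwards: pull the logged values out of
        # model._results and hand them to the logger connector / progress bar
        trainer.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
        return output
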
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -436,7 +436,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):

                 # clean up
                 self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
-                self.evaluation_loop.log_step_metrics(output, batch_idx)
+                self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)

                 # track epoch level metrics
                 if output is not None:
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -542,7 +542,7 @@ def run_training_epoch(self):
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
             # -----------------------------------------
-            self.trainer.logger_connector.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
+            self.trainer.logger_connector.log_train_step_metrics(batch_idx, batch_output)

             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -573,14 +573,17 @@ def run_training_epoch(self):
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

-        # process epoch outputs
-        self.trainer.logger_connector.on_train_epoch_end(
+        # log epoch metrics
+        self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output,
             self.checkpoint_accumulator,
             self.early_stopping_accumulator,
             self.num_optimizers
         )

+        # hook
+        self.trainer.logger_connector.on_train_epoch_end(epoch_output)
+
         # when no val loop is present or fast-dev-run still need to call checkpoints
         self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))

@@ -704,6 +707,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

         # track all metrics for callbacks
+        # TODO: is this needed?
         self.trainer.logger_connector.callback_metrics.update(
             {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None}
         )
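
One detail on the log_train_step_metrics call introduced in the first training_loop.py hunk above: the renamed logger_connector method only writes to the loggers every row_log_interval steps (see the condition in the -413 hunk of logger_connector.py). A small worked example of that gating, with illustrative values (the test below sets row_log_interval=1 so every step is logged):

    row_log_interval = 10
    should_stop = False
    for batch_idx in range(30):
        # same condition as in log_train_step_metrics
        should_log_metrics = (batch_idx + 1) % row_log_interval == 0 or should_stop
        # True at batch_idx 9, 19, 29 -> metrics reach the logger every 10th batch
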
56 changes: 56 additions & 0 deletions tests/trainer/test_trainining_step_no_dict_result.py
@@ -215,3 +215,59 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
# epoch 1
assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3


def test_validation_step_logging(tmpdir):
"""
Tests that only training_step can be used
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(DeterministicModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
self.log('train_step_acc', acc, on_step=True, on_epoch=True)
self.training_step_called = True
return acc

def validation_step(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
self.log('val_step_acc', acc, on_step=True, on_epoch=True)
self.validation_step_called = True

def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()

model = TestModel()
model.validation_step_end = None
model.validation_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
row_log_interval=1,
weights_summary=None,
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
expected_logged_metrics = {
'epoch',
'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
'val_step_acc/epoch_0', 'val_step_acc/epoch_1',
'step_val_step_acc/epoch_0', 'step_val_step_acc/epoch_1',
}
logged_metrics = set(trainer.logged_metrics.keys())
assert expected_logged_metrics == logged_metrics

# we don't want to enable val metrics during steps because it is not something that users should do
expected_cb_metrics = [
'train_step_acc', 'step_train_step_acc', 'epoch_train_step_acc',
]
expected_cb_metrics = set(expected_cb_metrics)
callback_metrics = set(trainer.callback_metrics.keys())
assert expected_cb_metrics == callback_metrics
