Skip to content

Commit

Permalink
ref: (2/n) fix no log in epoch end (#3699)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Sep 28, 2020
1 parent 859ec92 commit 2ecaa2a
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
35 changes: 18 additions & 17 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,30 +272,31 @@ def log_train_epoch_end_metrics(self,
self.callback_metrics.update(epoch_progress_bar_metrics)

def training_epoch_end(self, model, epoch_output, num_optimizers):
    """Run the LightningModule's ``training_epoch_end`` hook and collect its epoch-level logs.

    Args:
        model: the LightningModule whose hook is invoked.
        epoch_output: a list with one result collection per optimizer index.
        num_optimizers: number of optimizers used during the epoch.

    Returns:
        Result: the metrics the user logged inside ``training_epoch_end``
        (an empty ``Result`` when the hook is not overridden).

    Raises:
        MisconfigurationException: if the user's hook returns anything other than ``None``.
    """
    # guard clause: nothing to run (and nothing to log) when the hook is not overridden
    if not is_overridden('training_epoch_end', model=model):
        return Result()

    # refresh the result for custom logging at the epoch level
    model._current_fx_name = 'training_epoch_end'
    model._results = Result()

    epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

    # with a single optimizer, the user's hook expects the bare list, not a list-of-lists
    if num_optimizers == 1:
        epoch_output = epoch_output[0]

    # lightningmodule hook
    epoch_output = model.training_epoch_end(epoch_output)

    # clear the fx name so later self.log calls aren't attributed to this hook
    model._current_fx_name = ''

    if epoch_output is not None:
        raise MisconfigurationException('training_epoch_end expects a return of None. '
                                        'HINT: remove the return statement in training_epoch_end')

    # user can ALSO log at the end of an epoch
    new_epoch_end_logs = model._results
    return new_epoch_end_logs

def __run_legacy_training_epoch_end(
self,
Expand Down
54 changes: 54 additions & 0 deletions tests/trainer/test_trainining_step_no_dict_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,60 @@
import torch


def test_training_step_scalar_no_epoch_end_log(tmpdir):
    """Logging from training_step alone (no epoch-end hook) must populate callback metrics."""
    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(DeterministicModel):
        def training_step(self, batch, batch_idx):
            acc = self.step(batch, batch_idx) + batch_idx

            # one entry per step/epoch/progress-bar logging combination
            log_specs = [
                ('step_acc', dict(on_step=True, on_epoch=False)),
                ('epoch_acc', dict(on_step=False, on_epoch=True)),
                ('no_prefix_step_epoch_acc', dict(on_step=True, on_epoch=True)),
                ('pbar_step_acc', dict(on_step=True, prog_bar=True, on_epoch=False, logger=False)),
                ('pbar_epoch_acc', dict(on_step=False, on_epoch=True, prog_bar=True, logger=False)),
                ('pbar_step_epoch_acc', dict(on_step=True, on_epoch=True, prog_bar=True, logger=False)),
            ]
            for name, kwargs in log_specs:
                self.log(name, acc, **kwargs)

            self.training_step_called = True
            return acc

        def backward(self, trainer, loss, optimizer, optimizer_idx):
            loss.backward()

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        row_log_interval=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called

    # every logged metric (plus step_/epoch_ variants and the debug marker) must reach callbacks
    expected_metrics = {
        'step_acc',
        'epoch_acc',
        'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc',
        'pbar_step_acc',
        'pbar_epoch_acc',
        'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc',
        'debug_epoch',
    }
    assert expected_metrics == set(trainer.callback_metrics)


def test_training_step_scalar(tmpdir):
"""
Tests that only training_step can be used
Expand Down

0 comments on commit 2ecaa2a

Please sign in to comment.