add tests for single scalar return from training #2587

Merged
merged 5 commits on Jul 11, 2020
11 changes: 11 additions & 0 deletions pytorch_lightning/trainer/logging.py
@@ -98,6 +98,17 @@ def process_output(self, output, train=False):

Separates loss from logging and progress bar metrics
"""
# --------------------------
# handle single scalar only
# --------------------------
# single scalar returned from an xx_step
if isinstance(output, torch.Tensor):
    progress_bar_metrics = {}
    log_metrics = {}
    callback_metrics = {}
    hiddens = None
    return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens

# ---------------
# EXTRACT CALLBACK KEYS
# ---------------
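For context, the kind of model the new branch above is meant to support looks like the sketch below: training_step returns a bare scalar loss tensor instead of a dict. The model is hypothetical (not part of this diff) and assumes only the standard LightningModule hooks.

# Hypothetical example, not part of this PR: a LightningModule whose
# training_step returns a bare scalar tensor, exercising the branch above.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ScalarReturnModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        return loss  # scalar tensor, no dict: handled by the isinstance(output, torch.Tensor) branch

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

With the new branch, process_output passes such a scalar through as the loss and returns empty progress bar, log, and callback metric dicts instead of assuming a dict-shaped output.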
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -792,7 +792,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
)

# if the user decides to finally reduce things in epoch_end, save raw output without graphs
- training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+ if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+     training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+ else:
+     training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
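The intent of the change above: a plain scalar can be detached directly, while dict-shaped outputs still go through recursive_detach. A rough illustration of the resulting behaviour follows; detach_for_epoch_end is a stand-in helper written only for this sketch, not a pytorch_lightning function.

# Illustration only: mirrors the detach logic above for scalar vs. dict outputs.
# detach_for_epoch_end is a hypothetical helper, not part of pytorch_lightning.
import torch


def detach_for_epoch_end(output):
    if isinstance(output, torch.Tensor):
        return output.detach()  # scalar case: keep the value, drop the autograd graph
    if isinstance(output, dict):
        return {k: detach_for_epoch_end(v) if isinstance(v, torch.Tensor) else v
                for k, v in output.items()}
    return output


loss = (torch.randn(4, requires_grad=True) ** 2).sum()
assert loss.grad_fn is not None                                      # attached to a graph
assert detach_for_epoch_end(loss).grad_fn is None                    # scalar return: detached
assert detach_for_epoch_end({'loss': loss})['loss'].grad_fn is None  # dict return: detached too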
41 changes: 41 additions & 0 deletions tests/base/deterministic_model.py
@@ -52,6 +52,47 @@ def count_num_graphs(self, result, num_graphs=0):

return num_graphs

# ---------------------------
# scalar return
# ---------------------------
def training_step_scalar_return(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
self.training_step_called = True
return acc

def training_step_end_scalar(self, output):
self.training_step_end_called = True

# make sure loss has the grad
assert isinstance(output, torch.Tensor)

Review comment from @Borda (Member), Jul 11, 2020:
this looks like a test, but why is it in DeterministicModel?
well, maybe rather ask why these tests are not part of the template?

assert output.grad_fn is not None

# make sure nothing else has grads
assert self.count_num_graphs({'loss': output}) == 1

assert output == 171

return output

def training_epoch_end_scalar(self, outputs):
"""
There should be an array of scalars without graphs that are all 171 (4 of them)
"""
self.training_epoch_end_called = True

if self.use_dp or self.use_ddp2:
pass
else:
# only saw 4 batches
assert len(outputs) == 4
for batch_out in outputs:
assert batch_out == 171
assert batch_out.grad_fn is None
assert isinstance(batch_out, torch.Tensor)

prototype_loss = outputs[0]
return prototype_loss

# --------------------------
# dictionary returns
# --------------------------
@@ -1,10 +1,10 @@
"""
Tests to ensure that the training loop works with a dict
"""
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
import pytest
import torch


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_training_step_dict(tmpdir):
"""
Tests that only training_step can be used
165 changes: 165 additions & 0 deletions tests/trainer/test_trainer_steps_scalar_return.py
@@ -0,0 +1,165 @@
"""
Tests to ensure that the training loop works with a scalar
"""
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
import torch


def test_training_step_scalar(tmpdir):
"""
Tests that only training_step that returns a single scalar can be used
"""
model = DeterministicModel()
model.training_step = model.training_step_scalar_return
model.val_dataloader = None

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
weights_summary=None,
)
trainer.fit(model)

# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert not model.training_epoch_end_called

# make sure training outputs what is expected

Review comment (Member):
all these asserts below seem to be the same in all three functions... rather wrap them into a single assert block, otherwise it takes some time to check what the difference is...

for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


def test_training_step_scalar_with_step_end(tmpdir):
"""
Checks train_step with scalar only + training_step_end
"""
model = DeterministicModel()
model.training_step = model.training_step_scalar_return
model.training_step_end = model.training_step_end_scalar
model.val_dataloader = None

trainer = Trainer(fast_dev_run=True, weights_summary=None)

Review comment (Member): missing tmpdir

trainer.fit(model)

# make sure correct steps were called
assert model.training_step_called
assert model.training_step_end_called
assert not model.training_epoch_end_called

# make sure training outputs what is expected
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


def test_full_training_loop_scalar(tmpdir):
"""
Checks train_step + training_step_end + training_epoch_end
(all with scalar return from train_step)
"""
model = DeterministicModel()
model.training_step = model.training_step_scalar_return
model.training_step_end = model.training_step_end_scalar
model.training_epoch_end = model.training_epoch_end_scalar
model.val_dataloader = None

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)

# make sure correct steps were called
assert model.training_step_called
assert model.training_step_end_called
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


def test_train_step_epoch_end_scalar(tmpdir):
"""
Checks train_step + training_epoch_end (NO training_step_end)
(with scalar return)
"""
model = DeterministicModel()
model.training_step = model.training_step_scalar_return
model.training_step_end = None
model.training_epoch_end = model.training_epoch_end_scalar
model.val_dataloader = None

trainer = Trainer(max_epochs=1, weights_summary=None)
trainer.fit(model)

# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
for batch_idx, batch in enumerate(model.train_dataloader()):

Review comment (Member): why not set these directly?
batch_idx = 0
batch = model.train_dataloader()[0]

break

out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, torch.Tensor)
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171
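
Following up on the review comment about the repeated asserts: these tests check the same things on the first training batch, so they could share a helper along these lines. This is a sketch, not part of this PR; the name assert_scalar_training_outputs is made up here.

# Possible shared assert block suggested in review (sketch, not part of this PR).
import torch


def assert_scalar_training_outputs(trainer, model):
    # grab the first training batch
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
    assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, torch.Tensor)
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171

Each test would then call this helper after trainer.fit(model); the module itself can be run in isolation with pytest tests/trainer/test_trainer_steps_scalar_return.py.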