Fix weights_summary
rohitgr7 committed Jan 1, 2021
1 parent ccffc34 commit d6bc41b
Showing 3 changed files with 27 additions and 15 deletions.
14 changes: 11 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -1330,9 +1330,17 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
-        model_summary = ModelSummary(self, mode=mode)
-        log.info("\n" + str(model_summary))
+    def summarize(self, weights_summary: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
+        model_summary = None
+
+        if weights_summary in ModelSummary.MODES:
+            model_summary = ModelSummary(self, mode=weights_summary)
+            log.info("\n" + str(model_summary))
+        elif weights_summary is not None:
+            raise MisconfigurationException(
+                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
+            )
+
         return model_summary
 
     def freeze(self) -> None:
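The net effect of the lightning.py change is that `summarize` now owns the validation of `weights_summary` and degrades gracefully for `None`. A minimal sketch of the resulting behaviour; `TinyModel` is a hypothetical module written only for this illustration, not part of the commit:

```python
# Sketch only: exercising the new `summarize` signature shown in the diff above.
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TinyModel(LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()

# A valid mode ("top" or "full") logs the summary table and returns a ModelSummary.
assert isinstance(model.summarize(weights_summary="top"), ModelSummary)

# None skips the summary and returns None instead of raising.
assert model.summarize(weights_summary=None) is None

# Any other value now raises MisconfigurationException from summarize itself.
try:
    model.summarize(weights_summary="typo")
except MisconfigurationException as err:
    print(err)
```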
7 changes: 2 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -162,11 +162,8 @@ def setup_training(self, model: LightningModule):
         ref_model.on_pretrain_routine_start()
 
         # print model summary
-        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
-            if self.trainer.weights_summary in ModelSummary.MODES:
-                ref_model.summarize(mode=self.trainer.weights_summary)
-            else:
-                raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
+        if self.trainer.is_global_zero and not self.trainer.testing:
+            ref_model.summarize(weights_summary=self.trainer.weights_summary)
 
         # track model now.
         # if cluster resets state, the model will update with the saved weights
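On the trainer side the mode check disappears entirely: the call is guarded only by `is_global_zero` and `testing`, and `summarize` decides whether to log, skip, or raise. A rough sketch of the equivalent guard as a free function (the name `print_model_summary` is made up for illustration, not the real `setup_training` body):

```python
# Sketch only: mirrors the guard used in setup_training after this commit.
def print_model_summary(trainer, ref_model):
    # Only the global-zero process prints, and never during testing;
    # handling of None / "top" / "full" / invalid values now lives in
    # LightningModule.summarize.
    if trainer.is_global_zero and not trainer.testing:
        ref_model.summarize(weights_summary=trainer.weights_summary)
```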
21 changes: 14 additions & 7 deletions tests/core/test_memory.py
@@ -16,7 +16,8 @@
 import torch.nn as nn
 
 from pytorch_lightning import LightningModule
-from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
+from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.models import ParityModuleRNN
 
 
@@ -68,14 +69,20 @@ def forward(self, x):
         return self.reduce(self.embed(x))
 
 
+def test_invalid_weights_summmary():
+    """ Test that invalid value for weights_summary raises an error. """
+    with pytest.raises(MisconfigurationException, match='can be None, .* got temp'):
+        UnorderedModel().summarize(weights_summary='temp')
+
+
 @pytest.mark.parametrize(['mode'], [
     pytest.param(ModelSummary.MODE_FULL),
     pytest.param(ModelSummary.MODE_TOP),
 ])
 def test_empty_model_summary_shapes(mode):
     """ Test that the summary works for models that have no submodules. """
     model = EmptyModule()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == []
     assert summary.out_sizes == []
     assert summary.param_nums == []
@@ -95,7 +102,7 @@ def test_linear_model_summary_shapes(device, mode):
     """ Test that the model summary correctly computes the input- and output shapes. """
     model = UnorderedModel().to(device)
     model.train()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [
         [2, 10],  # layer 2
         [2, 7],  # combine
@@ -158,7 +165,7 @@ def test_rnn_summary_shapes(mode):
 
     model.example_input_array = torch.zeros(b, t, 10)
 
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [
         [b, t, i],  # rnn
         [b, t, h],  # linear
@@ -176,7 +183,7 @@ def test_summary_parameter_count(mode):
 def test_summary_parameter_count(mode):
     """ Test that the summary counts the number of parameters in every submodule. """
     model = UnorderedModel()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.param_nums == [
         model.layer2.weight.numel() + model.layer2.bias.numel(),
         model.combine.weight.numel() + model.combine.bias.numel(),
@@ -193,7 +200,7 @@ def test_summary_parameter_count(mode):
 def test_summary_layer_types(mode):
     """ Test that the summary displays the layer names correctly. """
     model = UnorderedModel()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.layer_types == [
         'Linear',
         'Linear',
@@ -235,5 +242,5 @@ def forward(self, *args, **kwargs):
 
     model = DummyLightningModule()
     model.example_input_array = example_input
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [expected_size]
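To exercise the updated tests locally, the touched module can be run on its own; a sketch assuming pytest and the repository's test dependencies are installed:

```python
# Sketch: run the test module touched by this commit, including the new
# test_invalid_weights_summmary check for the MisconfigurationException path.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/core/test_memory.py", "-v"]))
```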
