From c39ac09e220f70234096877bc4dcd4897a618d5b Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 30 Dec 2020 02:35:27 +0530
Subject: [PATCH 1/5] Fix weights_summary

---
 pytorch_lightning/core/lightning.py        | 14 +++++++++++---
 pytorch_lightning/trainer/training_loop.py |  7 ++-----
 tests/core/test_memory.py                  | 21 ++++++++++++++-------
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index a4330b401936d..e4d4b9d7b26ab 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1327,9 +1327,17 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
-        model_summary = ModelSummary(self, mode=mode)
-        log.info("\n" + str(model_summary))
+    def summarize(self, weights_summary: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
+        model_summary = None
+
+        if weights_summary in ModelSummary.MODES:
+            model_summary = ModelSummary(self, mode=weights_summary)
+            log.info("\n" + str(model_summary))
+        elif weights_summary is not None:
+            raise MisconfigurationException(
+                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
+            )
+
         return model_summary
 
     def freeze(self) -> None:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index fe4525006ebb9..ee92e5ec8581a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -161,11 +161,8 @@ def setup_training(self, model: LightningModule):
         ref_model.on_pretrain_routine_start()
 
         # print model summary
-        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
-            if self.trainer.weights_summary in ModelSummary.MODES:
-                ref_model.summarize(mode=self.trainer.weights_summary)
-            else:
-                raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
+        if self.trainer.is_global_zero and not self.trainer.testing:
+            ref_model.summarize(weights_summary=self.trainer.weights_summary)
 
         # track model now.
         # if cluster resets state, the model will update with the saved weights
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index cfeb302134d24..1f6b1a789ed5d 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -16,7 +16,8 @@
 import torch.nn as nn
 
 from pytorch_lightning import LightningModule
-from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
+from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.models import ParityModuleRNN
 
 
@@ -68,6 +69,12 @@ def forward(self, x):
         return self.reduce(self.embed(x))
 
 
+def test_invalid_weights_summary():
+    """ Test that an invalid value for weights_summary raises an error. """
+    with pytest.raises(MisconfigurationException, match='can be None, .* got temp'):
+        UnorderedModel().summarize(weights_summary='temp')
+
+
 @pytest.mark.parametrize(['mode'], [
     pytest.param(ModelSummary.MODE_FULL),
     pytest.param(ModelSummary.MODE_TOP),
 ])
@@ -75,7 +82,7 @@ def test_empty_model_summary_shapes(mode):
     """ Test that the summary works for models that have no submodules. """
     model = EmptyModule()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == []
     assert summary.out_sizes == []
     assert summary.param_nums == []
@@ -95,7 +102,7 @@ def test_linear_model_summary_shapes(device, mode):
     """ Test that the model summary correctly computes the input- and output shapes. """
     model = UnorderedModel().to(device)
     model.train()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [
         [2, 10],  # layer 2
         [2, 7],  # combine
@@ -158,7 +165,7 @@ def test_rnn_summary_shapes(mode):
 
     model.example_input_array = torch.zeros(b, t, 10)
 
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [
         [b, t, i],  # rnn
         [b, t, h],  # linear
@@ -176,7 +183,7 @@ def test_summary_parameter_count(mode):
     """ Test that the summary counts the number of parameters in every submodule. """
     model = UnorderedModel()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.param_nums == [
         model.layer2.weight.numel() + model.layer2.bias.numel(),
         model.combine.weight.numel() + model.combine.bias.numel(),
@@ -193,7 +200,7 @@ def test_summary_layer_types(mode):
     """ Test that the summary displays the layer names correctly. """
     model = UnorderedModel()
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.layer_types == [
         'Linear',
         'Linear',
@@ -235,5 +242,5 @@ def forward(self, *args, **kwargs):
 
     model = DummyLightningModule()
     model.example_input_array = example_input
-    summary = model.summarize(mode=mode)
+    summary = model.summarize(weights_summary=mode)
     assert summary.in_sizes == [expected_size]
""" model = EmptyModule() - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.in_sizes == [] assert summary.out_sizes == [] assert summary.param_nums == [] @@ -95,7 +102,7 @@ def test_linear_model_summary_shapes(device, mode): """ Test that the model summary correctly computes the input- and output shapes. """ model = UnorderedModel().to(device) model.train() - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.in_sizes == [ [2, 10], # layer 2 [2, 7], # combine @@ -158,7 +165,7 @@ def test_rnn_summary_shapes(mode): model.example_input_array = torch.zeros(b, t, 10) - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.in_sizes == [ [b, t, i], # rnn [b, t, h], # linear @@ -176,7 +183,7 @@ def test_rnn_summary_shapes(mode): def test_summary_parameter_count(mode): """ Test that the summary counts the number of parameters in every submodule. """ model = UnorderedModel() - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.param_nums == [ model.layer2.weight.numel() + model.layer2.bias.numel(), model.combine.weight.numel() + model.combine.bias.numel(), @@ -193,7 +200,7 @@ def test_summary_parameter_count(mode): def test_summary_layer_types(mode): """ Test that the summary displays the layer names correctly. """ model = UnorderedModel() - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.layer_types == [ 'Linear', 'Linear', @@ -235,5 +242,5 @@ def forward(self, *args, **kwargs): model = DummyLightningModule() model.example_input_array = example_input - summary = model.summarize(mode=mode) + summary = model.summarize(weights_summary=mode) assert summary.in_sizes == [expected_size] From 3f4f4bdef1f51daea0b11c9a3b48d192d7175e46 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 2 Jan 2021 19:58:45 +0530 Subject: [PATCH 2/5] use mode --- pytorch_lightning/core/lightning.py | 10 +++++----- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 15 +++++++++++++- tests/core/test_memory.py | 23 ++++++++++++---------- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e4d4b9d7b26ab..193abd5263362 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1327,15 +1327,15 @@ def tbptt_split_batch(self, batch, split_size): return splits - def summarize(self, weights_summary: str = ModelSummary.MODE_DEFAULT) -> ModelSummary: + def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary: model_summary = None - if weights_summary in ModelSummary.MODES: - model_summary = ModelSummary(self, mode=weights_summary) + if mode in ModelSummary.MODES: + model_summary = ModelSummary(self, mode=mode) log.info("\n" + str(model_summary)) - elif weights_summary is not None: + elif mode is not None: raise MisconfigurationException( - f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}" + f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}" ) return model_summary diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c66cc3a43d0b1..42336c0c60972 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -309,7 +309,6 @@ def __init__( self.plugin_connector = 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ee92e5ec8581a..c183c15f01a3c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -49,7 +49,14 @@ def __init__(self, trainer):
         self._cur_grad_norm_dict = None
 
     def on_trainer_init(
-        self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization
+        self,
+        max_epochs,
+        min_epochs,
+        max_steps,
+        min_steps,
+        num_sanity_val_steps,
+        automatic_optimization,
+        weights_summary,
     ):
         self.trainer.global_step = 0
         self.trainer.current_epoch = 0
@@ -73,6 +80,12 @@ def on_trainer_init(
         else:
             self.trainer.num_sanity_val_steps = num_sanity_val_steps
 
+        self.trainer.weights_summary = weights_summary
+        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
+            raise MisconfigurationException(
+                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
+            )
+
     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index 1f6b1a789ed5d..142159fa48fd8 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -15,8 +15,8 @@
 import torch
 import torch.nn as nn
 
-from pytorch_lightning import LightningModule
-from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.models import ParityModuleRNN
 
@@ -71,8 +71,11 @@ def forward(self, x):
 
 def test_invalid_weights_summary():
     """ Test that an invalid value for weights_summary raises an error. """
-    with pytest.raises(MisconfigurationException, match='can be None, .* got temp'):
-        UnorderedModel().summarize(weights_summary='temp')
+    with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
+        UnorderedModel().summarize(mode='temp')
+
+    with pytest.raises(MisconfigurationException, match='`weights_summary` can be None, .* got temp'):
+        Trainer(weights_summary='temp')
 
 
 @pytest.mark.parametrize(['mode'], [
@@ -82,7 +85,7 @@ def test_empty_model_summary_shapes(mode):
     """ Test that the summary works for models that have no submodules. """
     model = EmptyModule()
-    summary = model.summarize(weights_summary=mode)
+    summary = model.summarize(mode=mode)
     assert summary.in_sizes == []
     assert summary.out_sizes == []
     assert summary.param_nums == []
@@ -102,7 +105,7 @@ def test_linear_model_summary_shapes(device, mode):
""" model = UnorderedModel().to(device) model.train() - summary = model.summarize(weights_summary=mode) + summary = model.summarize(mode=mode) assert summary.in_sizes == [ [2, 10], # layer 2 [2, 7], # combine @@ -165,7 +168,7 @@ def test_rnn_summary_shapes(mode): model.example_input_array = torch.zeros(b, t, 10) - summary = model.summarize(weights_summary=mode) + summary = model.summarize(mode=mode) assert summary.in_sizes == [ [b, t, i], # rnn [b, t, h], # linear @@ -183,7 +186,7 @@ def test_rnn_summary_shapes(mode): def test_summary_parameter_count(mode): """ Test that the summary counts the number of parameters in every submodule. """ model = UnorderedModel() - summary = model.summarize(weights_summary=mode) + summary = model.summarize(mode=mode) assert summary.param_nums == [ model.layer2.weight.numel() + model.layer2.bias.numel(), model.combine.weight.numel() + model.combine.bias.numel(), @@ -200,7 +203,7 @@ def test_summary_parameter_count(mode): def test_summary_layer_types(mode): """ Test that the summary displays the layer names correctly. """ model = UnorderedModel() - summary = model.summarize(weights_summary=mode) + summary = model.summarize(mode=mode) assert summary.layer_types == [ 'Linear', 'Linear', @@ -242,5 +245,5 @@ def forward(self, *args, **kwargs): model = DummyLightningModule() model.example_input_array = example_input - summary = model.summarize(weights_summary=mode) + summary = model.summarize(mode=mode) assert summary.in_sizes == [expected_size] From 03f8684d77de7a36979fddd97e1a4ccabd6c7645 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 2 Jan 2021 21:03:14 +0530 Subject: [PATCH 3/5] fix --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c183c15f01a3c..dacb57302fdd8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -175,7 +175,7 @@ def setup_training(self, model: LightningModule): # print model summary if self.trainer.is_global_zero and not self.trainer.testing: - ref_model.summarize(weights_summary=self.trainer.weights_summary) + ref_model.summarize(mode=self.trainer.weights_summary) # track model now. 
         # if cluster resets state, the model will update with the saved weights

From 479e21ff215473ddbe4dda6bb88cc139d8d81578 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Mon, 4 Jan 2021 23:34:16 +0530
Subject: [PATCH 4/5] optional

---
 pytorch_lightning/core/lightning.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 193abd5263362..69a9da20afe74 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -14,15 +14,15 @@
 
 """nn.Module with additional great features."""
 
-from abc import ABC
-from argparse import Namespace
 import collections
 import copy
 import inspect
 import os
-from pathlib import Path
 import re
 import tempfile
+from abc import ABC
+from argparse import Namespace
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
@@ -1327,7 +1327,7 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
+    def summarize(self, mode: str = Optional[ModelSummary.MODE_DEFAULT]) -> Optional[ModelSummary]:
         model_summary = None
 
         if mode in ModelSummary.MODES:

From 46d97b19da0a17bf9ec3f394fd7493f9acd80891 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Mon, 4 Jan 2021 23:53:57 +0530
Subject: [PATCH 5/5] what was I thinking

---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 69a9da20afe74..421fc5e5cf2ac 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1327,7 +1327,7 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def summarize(self, mode: str = Optional[ModelSummary.MODE_DEFAULT]) -> Optional[ModelSummary]:
+    def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
         model_summary = None
 
         if mode in ModelSummary.MODES:
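For reference, a minimal sketch of the behavior this series converges on: `summarize` takes an optional `mode`, returns an `Optional[ModelSummary]`, and an invalid value raises `MisconfigurationException`, either from `LightningModule.summarize` or at `Trainer` construction via `on_trainer_init`. The `LitModel` module below is a hypothetical stand-in used only to exercise the API; it is not part of the patches.

    import torch.nn as nn

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.core.memory import ModelSummary
    from pytorch_lightning.utilities.exceptions import MisconfigurationException


    class LitModel(LightningModule):
        # hypothetical single-layer module, just enough to call summarize()
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(10, 2)

        def forward(self, x):
            return self.layer(x)


    model = LitModel()

    # a valid mode ('full' or 'top') logs the summary and returns a ModelSummary
    summary = model.summarize(mode=ModelSummary.MODE_TOP)

    # mode=None is now allowed: nothing is logged and None is returned
    assert model.summarize(mode=None) is None

    # anything else raises, both on the module and when constructing the Trainer
    try:
        model.summarize(mode='temp')
    except MisconfigurationException:
        pass

    try:
        Trainer(weights_summary='temp')
    except MisconfigurationException:
        pass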