Count number of modules in train/eval mode in ModelSummary (#20159)
awaelchli authored Aug 4, 2024
1 parent 60fe36a commit d4de8e2
Showing 10 changed files with 74 additions and 7 deletions.
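
Before the per-file diff, a minimal usage sketch of what this commit adds (the model and layer names below are illustrative, not taken from the commit): `ModelSummary` gains a `total_training_modes` property, and the printed summary table gains two footer rows.

import torch.nn as nn
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import ModelSummary


class LitModel(LightningModule):  # hypothetical demo model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
        self.head = nn.Linear(16, 2)


model = LitModel()
model.encoder.eval()  # e.g. a frozen backbone

summary = ModelSummary(model, max_depth=1)
# Every submodule except the root is counted: encoder, encoder.0, encoder.1, head
print(summary.total_training_modes)  # {'train': 1, 'eval': 3}
print(summary)  # the table footer now ends with "Modules in train mode" / "Modules in eval mode"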
1 change: 1 addition & 0 deletions docs/source-pytorch/advanced/transfer_learning.rst
@@ -116,6 +116,7 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
super().__init__()

self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
self.bert.train()
self.W = nn.Linear(self.bert.config.hidden_size, 3)
self.num_classes = 3

2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))

- Added the count of modules in train and eval mode to the printed `ModelSummary` table ([#20159](https://github.com/Lightning-AI/pytorch-lightning/pull/20159))

### Changed

- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
12 changes: 11 additions & 1 deletion src/lightning/pytorch/callbacks/model_summary.py
@@ -66,9 +66,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
total_parameters = model_summary.total_parameters
trainable_parameters = model_summary.trainable_parameters
model_size = model_summary.model_size
total_training_modes = model_summary.total_training_modes

if trainer.is_global_zero:
-self.summarize(summary_data, total_parameters, trainable_parameters, model_size, **self._summarize_kwargs)
+self.summarize(
+    summary_data,
+    total_parameters,
+    trainable_parameters,
+    model_size,
+    total_training_modes,
+    **self._summarize_kwargs,
+)

def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
@@ -83,12 +91,14 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
summary_table = _format_summary_table(
total_parameters,
trainable_parameters,
model_size,
total_training_modes,
*summary_data,
)
log.info("\n" + summary_table)
5 changes: 4 additions & 1 deletion src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple

from typing_extensions import override

@@ -71,6 +71,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
from rich import get_console
@@ -110,5 +111,7 @@ def summarize(
grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")

console.print(grid)
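
For context, a short sketch of how this callback is typically attached (a usage assumption, not part of the diff); with this change the rich summary printed at fit start ends with the two new rows.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary

# Replaces the default summary callback; requires the `rich` package.
trainer = Trainer(callbacks=[RichModelSummary(max_depth=2)])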
18 changes: 17 additions & 1 deletion src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -187,6 +187,8 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
@@ -198,6 +200,8 @@
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
3 Modules in train mode
0 Modules in eval mode
"""

@@ -252,6 +256,12 @@ def param_nums(self) -> List[int]:
def training_modes(self) -> List[bool]:
return [layer.training for layer in self._layer_summary.values()]

@property
def total_training_modes(self) -> Dict[str, int]:
modes = [layer.training for layer in self._model.modules()]
modes = modes[1:] # exclude the root module
return {"train": modes.count(True), "eval": modes.count(False)}

@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
@@ -351,8 +361,9 @@ def __str__(self) -> str:
total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size
total_training_modes = self.total_training_modes

-return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)
+return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)

def __repr__(self) -> str:
return str(self)
@@ -372,6 +383,7 @@ def _format_summary_table(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes: Dict[str, int],
*cols: Tuple[str, List[str]],
) -> str:
"""Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -408,6 +420,10 @@
summary += "Total params"
summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
summary += "Total estimated model params size (MB)"
summary += "\n" + s.format(total_training_modes["train"], 10)
summary += "Modules in train mode"
summary += "\n" + s.format(total_training_modes["eval"], 10)
summary += "Modules in eval mode"

return summary

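As a reference for the `total_training_modes` property added above, the counting rule reduces to this standalone sketch with plain torch modules (module names illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
model[1].eval()  # put one child in eval mode for illustration

# Mirror of the property: skip the first entry of modules(), the root module itself.
modes = [m.training for m in model.modules()][1:]
print({"train": modes.count(True), "eval": modes.count(False)})  # {'train': 2, 'eval': 1}
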
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -58,7 +58,7 @@ def on_train_epoch_end(self, trainer, pl_module):
self.saved_states.append(self.state_dict().copy())


-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True)  # Flaky test on Windows for unknown reasons
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_resume_early_stopping_from_checkpoint(tmp_path):
"""Prevent regressions to bugs:
3 changes: 3 additions & 0 deletions tests/tests_pytorch/callbacks/test_model_summary.py
@@ -49,6 +49,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
total_training_modes,
**summarize_kwargs: Any,
) -> None:
assert summary_data[1][0] == "Name"
@@ -64,6 +65,8 @@
assert summary_data[4][0] == "Mode"
assert summary_data[4][1][0] == "train"

assert total_training_modes == {"train": 1, "eval": 0}

model = BoringModel()
trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1)

8 changes: 7 additions & 1 deletion tests/tests_pytorch/callbacks/test_rich_model_summary.py
@@ -56,7 +56,13 @@ def example_input_array(self) -> Any:
summary = summarize(model)
summary_data = summary._get_summary_data()

-model_summary.summarize(summary_data=summary_data, total_parameters=1, trainable_parameters=1, model_size=1)
+model_summary.summarize(
+    summary_data=summary_data,
+    total_parameters=1,
+    trainable_parameters=1,
+    model_size=1,
+    total_training_modes=summary.total_training_modes,
+)

# ensure that summary was logged + the breakdown of model parameters
assert mock_console.call_count == 2
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_datamodules.py
@@ -218,7 +218,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert dm.my_state_dict == {"my": "state_dict"}


-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True)  # Flaky test on Windows for unknown reasons
def test_full_loop(tmp_path):
seed_everything(7)

28 changes: 27 additions & 1 deletion tests/tests_pytorch/utilities/test_model_summary.py
@@ -423,6 +423,29 @@ def forward(self, x):
assert not model.layer2.training


def test_total_training_modes():
"""Test that the `total_training_modes` counts the modules in 'train' and 'eval' mode, excluding the root
module."""

class ModelWithoutChildren(LightningModule):
pass

summary = ModelSummary(ModelWithoutChildren())
assert summary.total_training_modes == {"train": 0, "eval": 0}

model = DeepNestedModel()
summary = ModelSummary(model)
assert summary.total_training_modes == {"train": 19, "eval": 0}
assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1

model = DeepNestedModel()
model.branch1[1][0].eval()
model.branch2.eval()
summary = ModelSummary(model)
assert summary.total_training_modes == {"train": 17, "eval": 2}
assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1


def test_summary_training_mode():
"""Test that the model summary captures the training mode on all submodules."""
model = DeepNestedModel()
@@ -436,6 +459,7 @@ def test_summary_training_mode():
"eval", # branch2
"train", # head
]
assert summary.total_training_modes == {"train": 17, "eval": 2}

summary = summarize(model, max_depth=-1)
expected_eval = {"branch1.1.0", "branch2"}
@@ -445,5 +469,7 @@
# A model with params not belonging to a layer
model = NonLayerParamsModel()
model.layer.eval()
-summary_data = OrderedDict(summarize(model)._get_summary_data())
+summary = summarize(model)
+summary_data = OrderedDict(summary._get_summary_data())
assert summary_data["Mode"] == ["eval", "n/a"]
assert summary.total_training_modes == {"train": 0, "eval": 1}
