Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ModelSummary Callback #9344

Merged
merged 17 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))


- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
Expand All @@ -39,6 +40,7 @@
"LearningRateMonitor",
"ModelCheckpoint",
"ModelPruning",
"ModelSummary",
"BasePredictionWriter",
"ProgressBar",
"ProgressBarBase",
Expand Down
72 changes: 72 additions & 0 deletions pytorch_lightning/callbacks/model_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model Summary
=============

Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.

"""
import logging
from typing import List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize

log = logging.getLogger(__name__)


class ModelSummary(Callback):
    r"""
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0
            (or ``None``) turns the summary off. Default: 1.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelSummary
        >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
    """

    def __init__(self, max_depth: Optional[int] = 1):
        # The parameter accepts ``None``, so the attribute must be Optional as well.
        self._max_depth: Optional[int] = max_depth

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Summarize ``pl_module`` and log the resulting table.

        Nothing is logged when the process is not global rank zero, when the
        trainer is testing, or when ``max_depth`` is ``None`` or ``0``.
        """
        if not trainer.is_global_zero or trainer.testing:
            return
        # The class docstring promises that ``0`` disables the summary; the
        # previous check only skipped ``None`` and still summarized for ``0``.
        if not self._max_depth:
            return

        model_summary = summarize(pl_module, max_depth=self._max_depth)

        self.summarize(
            model_summary._get_summary_data(),
            model_summary.total_parameters,
            model_summary.trainable_parameters,
            model_summary.model_size,
        )

    @staticmethod
    def summarize(
        summary_data: List[List[Union[str, List[str]]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        """Format the summary table from the given data and log it at INFO level.

        Args:
            summary_data: The column entries, unpacked as positional arguments
                into ``_format_summary_table``.
            total_parameters: Passed through to ``_format_summary_table``.
            trainable_parameters: Passed through to ``_format_summary_table``.
            model_size: Passed through to ``_format_summary_table``.
        """
        summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)

        log.info("\n" + summary_table)
20 changes: 18 additions & 2 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from datetime import timedelta
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation

Expand All @@ -34,6 +34,7 @@ def on_trainer_init(
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
):
Expand All @@ -58,6 +59,8 @@ def on_trainer_init(
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)

self._configure_model_summary_callback(weights_summary)

# init progress bar
if process_position != 0:
rank_zero_deprecation(
Expand Down Expand Up @@ -89,6 +92,19 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
    """Append a ``ModelSummary`` callback derived from ``weights_summary``.

    Does nothing when the trainer already holds a ``ModelSummary`` callback
    or when ``weights_summary`` is ``None``.

    Args:
        weights_summary: ``None`` or one of ``ModelSummaryMode.supported_types()``.

    Raises:
        MisconfigurationException: If ``weights_summary`` is neither ``None``
            nor a supported mode.
    """
    if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
        return
    if weights_summary is None:
        return

    supported_modes = ModelSummaryMode.supported_types()
    if weights_summary not in supported_modes:
        # Adjacent f-strings with no comma between them, so the exception gets
        # ONE message. The earlier comma passed two separate arguments.
        raise MisconfigurationException(
            f"`weights_summary` can be None, {', '.join(supported_modes)},"
            f" but got {weights_summary}"
        )

    max_depth = ModelSummaryMode.get_max_depth(weights_summary)
    self.trainer.callbacks.append(ModelSummary(max_depth=max_depth))

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
return
Expand Down
12 changes: 1 addition & 11 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS

Expand Down Expand Up @@ -407,11 +406,6 @@ def __init__(
# default .predict() loop
self.predict_loop = PredictionLoop()

# training state
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
)
self.weights_summary = weights_summary

# init callbacks
Expand All @@ -423,6 +417,7 @@ def __init__(
process_position,
default_root_dir,
weights_save_path,
self.weights_summary,
stochastic_weight_avg,
max_time,
)
Expand Down Expand Up @@ -1108,11 +1103,6 @@ def _pre_training_routine(self):
# --------------------------
self.call_hook("on_pretrain_routine_start")

# print model summary
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
if self.is_global_zero and self.weights_summary is not None and not self.testing:
max_depth = ModelSummary.MODES[self.weights_summary]
summarize(self.lightning_module, max_depth=max_depth)

self.call_hook("on_pretrain_routine_end")

def _run_train(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DistributedType,
GradClipAlgorithmType,
LightningEnum,
ModelSummaryMode,
)
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401
Expand Down
34 changes: 33 additions & 1 deletion pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def supported_types() -> List[str]:


class DistributedType(LightningEnum):
"""Define type of ditributed computing.
"""Define type of distributed computing.

>>> # you can match the type with string
>>> DistributedType.DDP == 'ddp'
Expand Down Expand Up @@ -147,3 +147,35 @@ class AutoRestartBatchKeys(LightningEnum):
"""Defines special dictionary keys used to track captured dataset state with multiple workers."""

PL_RESTART_META = "__pl_restart_meta"


class ModelSummaryMode(LightningEnum):
    # TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`)
    """Named model-summary modes and their mapping to a ``max_depth`` value.

    Available modes:
        - `top`: only the top-level modules will be recorded (the children of the root module)
        - `full`: summarizes all layers and their submodules in the root module

    >>> # members compare equal to plain strings
    >>> ModelSummaryMode.TOP == 'TOP'
    True
    >>> # and the comparison ignores case
    >>> ModelSummaryMode.TOP in ('top', 'FULL')
    True
    """

    TOP = "top"
    FULL = "full"

    @staticmethod
    def get_max_depth(mode: str) -> int:
        """Return the ``max_depth`` for ``mode``: 1 for `top`, -1 for `full`.

        Raises:
            ValueError: If ``mode`` matches neither member.
        """
        depth_by_member = (
            (ModelSummaryMode.TOP, 1),
            (ModelSummaryMode.FULL, -1),
        )
        for member, depth in depth_by_member:
            # Keep ``mode`` on the left, exactly as the comparison was written before.
            if mode == member:
                return depth
        raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")

    @staticmethod
    def supported_types() -> List[str]:
        """Return the string value of every mode, in definition order."""
        return [member.value for member in ModelSummaryMode]
28 changes: 17 additions & 11 deletions pytorch_lightning/utilities/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.warnings import WarningCache
Expand Down Expand Up @@ -185,21 +185,21 @@ class ModelSummary:
0.530 Total estimated model params size (MB)
"""

MODES = dict(top=1, full=-1) # TODO: remove in v1.6

def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
self._model = model

# temporary mapping from mode to max_depth
if max_depth is None or mode is not None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
if mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `ModelSummary` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
)
else:
raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
raise MisconfigurationException(
f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}."
)

if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
Expand Down Expand Up @@ -295,7 +295,7 @@ def _forward_example_input(self) -> None:
model(input_)
model.train(mode) # restore mode of module

def __str__(self):
def _get_summary_data(self):
"""Makes a summary listing with:

Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
Expand All @@ -310,6 +310,11 @@ def __str__(self):
arrays.append(["In sizes", self.in_sizes])
arrays.append(["Out sizes", self.out_sizes])

return arrays

def __str__(self):
arrays = self._get_summary_data()

total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size
Expand Down Expand Up @@ -445,16 +450,17 @@ def summarize(

# temporary mapping from mode to max_depth
if max_depth is None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
if mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
)
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
raise MisconfigurationException(
f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}"
)
else:
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
log.info("\n" + str(model_summary))
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
return model_summary
Loading