Add enable_model_summary flag and deprecate weights_summary (Lightning-AI#9699)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <kaushikbokka@gmail.com>
5 people authored and rohitgr7 committed Oct 18, 2021
1 parent a6d1cc3 commit ee63840
Showing 46 changed files with 248 additions and 134 deletions.
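Taken together, the user-facing change can be sketched as a small migration map (a rough sketch inferred from the docs and CHANGELOG edits below; the old `full` option mapping to `max_depth=-1` is an assumption, not stated verbatim in this diff):

    # Hedged migration sketch, not part of this commit's code.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    trainer = Trainer()                                        # old: weights_summary="top" (default)
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # old: weights_summary="full" (assumed equivalent)
    trainer = Trainer(enable_model_summary=False)              # old: weights_summary=None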
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))


- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


### Changed

- Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))
@@ -344,6 +347,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Deprecated passing `weights_summary` to the `Trainer` constructor in favor of adding the `ModelSummary` callback with `max_depth` directly to the list of callbacks ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
2 changes: 1 addition & 1 deletion benchmarks/test_basic_parity.py
@@ -159,7 +159,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
# as the first run is skipped, no need to run it long
max_epochs=num_epochs if idx > 0 else 1,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
gpus=1 if device_type == "cuda" else 0,
checkpoint_callback=False,
logger=False,
12 changes: 8 additions & 4 deletions docs/source/common/debugging.rst
@@ -95,11 +95,14 @@ Print a summary of your LightningModule
---------------------------------------
Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
By default it only prints the top-level modules. If you want to show all submodules in your network, use the
`'full'` option:
``max_depth`` option:

.. testcode::

trainer = Trainer(weights_summary="full")
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])


You can also display the intermediate input- and output sizes of all your layers by setting the
``example_input_array`` attribute in your LightningModule. It will print a table like this
@@ -115,8 +118,9 @@ You can also display the intermediate input- and output sizes of all your layers
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.

See Also:
- :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
- :class:`~pytorch_lightning.core.memory.ModelSummary`
- :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
- :func:`~pytorch_lightning.utilities.model_summary.summarize`
- :class:`~pytorch_lightning.utilities.model_summary.ModelSummary`

----------------

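The docs paragraph above refers to a sample output table that is elided here; below is a minimal sketch of how `example_input_array` is typically set (the module structure and sizes are purely illustrative, not taken from this diff):

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
            # With this attribute set, the summary printed by .fit() also includes
            # the intermediate input and output sizes of each layer.
            self.example_input_array = torch.rand(1, 28 * 28)

        def forward(self, x):
            return self.layer(x)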
24 changes: 24 additions & 0 deletions docs/source/common/trainer.rst
@@ -1589,6 +1589,11 @@ Example::
weights_summary
^^^^^^^^^^^^^^^

.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument instead. To disable the model summary,
pass ``enable_model_summary = False`` to the Trainer.


.. raw:: html

<video width="50%" max-width="400px" controls
@@ -1611,6 +1616,25 @@ Options: 'full', 'top', None.
# don't print a summary
trainer = Trainer(weights_summary=None)


enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the model summarization. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])

-----

Trainer class API
2 changes: 1 addition & 1 deletion pl_examples/bug_report_model.py
@@ -55,7 +55,7 @@ def run():
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
trainer.test(model, dataloaders=test_data)
@@ -272,7 +272,7 @@ def add_arguments_to_parser(self, parser):
parser.set_defaults(
{
"trainer.max_epochs": 15,
"trainer.weights_summary": None,
"trainer.enable_model_summary": False,
"trainer.num_sanity_val_steps": 0,
}
)
51 changes: 39 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -45,6 +45,7 @@ def on_trainer_init(
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
enable_model_summary: bool,
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
@@ -101,7 +102,7 @@ def on_trainer_init(
self.trainer._progress_bar_callback = None

# configure the ModelSummary callback
self._configure_model_summary_callback(weights_summary)
self._configure_model_summary_callback(enable_model_summary, weights_summary)

# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)
@@ -159,24 +160,50 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e
if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
def _configure_model_summary_callback(
self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
if weights_summary is None:
rank_zero_deprecation(
"Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
" in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
)
return
if not enable_model_summary:
return

model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
if model_summary_cbs:
rank_zero_info(
f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
" Skipping setting a default `ModelSummary` callback."
)
return
if weights_summary is not None:

if weights_summary == "top":
# special case the default value for weights_summary to preserve backward compatibility
max_depth = 1
else:
rank_zero_deprecation(
f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
" `max_depth` directly to the Trainer's `callbacks` argument instead."
)
if weights_summary not in ModelSummaryMode.supported_types():
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
f" but got {weights_summary}",
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)
if self.trainer._progress_bar_callback is not None and isinstance(
self.trainer._progress_bar_callback, RichProgressBar
):
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer.weights_summary = weights_summary

is_progress_bar_rich = isinstance(self.trainer._progress_bar_callback, RichProgressBar)

if self.trainer._progress_bar_callback is not None and is_progress_bar_rich:
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer._weights_summary = weights_summary

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
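The precedence the new connector logic implements can be summarized as follows (a hedged behavioral sketch with illustrative asserts, not code from this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # Default: enable_model_summary=True and weights_summary="top" add one ModelSummary callback.
    trainer = Trainer()
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

    # Disabling the flag (or passing weights_summary=None) skips the default callback entirely.
    trainer = Trainer(enable_model_summary=False)
    assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

    # A user-provided ModelSummary takes precedence; no second summary callback is added.
    trainer = Trainer(callbacks=[ModelSummary(max_depth=2)])
    assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1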
23 changes: 23 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -157,6 +157,7 @@ def __init__(
accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
@@ -370,8 +371,16 @@ def __init__(
val_check_interval: How often to check the validation set. Use float to check within a training epoch,
use int to check every n steps (batches).
enable_model_summary: Whether to enable model summarization by default.
weights_summary: Prints a summary of the weights when training begins.
.. deprecated:: v1.5
``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument.
weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
@@ -463,6 +472,9 @@ def __init__(
self.tested_ckpt_path: Optional[str] = None
self.predicted_ckpt_path: Optional[str] = None

# todo: remove in v1.7
self._weights_summary: Optional[str] = None

# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
@@ -474,6 +486,7 @@ def __init__(
process_position,
default_root_dir,
weights_save_path,
enable_model_summary,
weights_summary,
stochastic_weight_avg,
max_time,
@@ -2023,6 +2036,16 @@ def _exit_gracefully_on_signal(self) -> None:
class_name = caller[0].f_locals["self"].__class__.__name__
raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")

@property
def weights_summary(self) -> Optional[str]:
rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
return self._weights_summary

@weights_summary.setter
def weights_summary(self, val: Optional[str]) -> None:
rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
self._weights_summary = val

"""
Other
"""
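A hedged sketch of how the new property deprecation would surface to callers (assuming `rank_zero_deprecation` emits a `DeprecationWarning` subclass that pytest can catch):

    import pytest
    from pytorch_lightning import Trainer

    def test_weights_summary_property_is_deprecated():
        trainer = Trainer()
        # Reading the property warns once per access.
        with pytest.deprecated_call(match="weights_summary"):
            _ = trainer.weights_summary
        # Setting it warns as well, but the value is still stored.
        with pytest.deprecated_call(match="weights_summary"):
            trainer.weights_summary = "full"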
4 changes: 2 additions & 2 deletions tests/accelerators/test_multi_nodes_gpu.py
@@ -54,7 +54,7 @@ def validation_step(self, batch, batch_idx):
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
@@ -101,7 +102,7 @@ def backward(self, loss, optimizer, optimizer_idx):
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
4 changes: 2 additions & 2 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -53,7 +53,7 @@ def training_epoch_end(self, outputs) -> None:
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
)

assert any(isinstance(c, CB) for c in trainer.callbacks)
@@ -74,7 +74,7 @@ def on_epoch_end(self, trainer, pl_module):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)

trainer.fit(model)
6 changes: 3 additions & 3 deletions tests/callbacks/test_gradient_accumulation_scheduler.py
@@ -31,7 +31,7 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
limit_train_batches=20,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
accumulate_grad_batches=accumulate_grad_batches,
)
assert trainer.accumulate_grad_batches == accumulate_grad_batches
@@ -56,7 +56,7 @@ def test_trainer_accumulate_grad_batches_dict_zero_grad(tmpdir, accumulate_grad_
limit_train_batches=10,
limit_val_batches=1,
max_epochs=4,
weights_summary=None,
enable_model_summary=False,
accumulate_grad_batches=accumulate_grad_batches,
)
assert trainer.accumulate_grad_batches == accumulate_grad_batches.get(0, 1)
@@ -74,7 +74,7 @@ def test_trainer_accumulate_grad_batches_with_callback(tmpdir):
limit_train_batches=10,
limit_val_batches=1,
max_epochs=4,
weights_summary=None,
enable_model_summary=False,
callbacks=[GradientAccumulationScheduler({1: 2, 3: 4})],
)
assert trainer.accumulate_grad_batches == 1
8 changes: 4 additions & 4 deletions tests/callbacks/test_lr_monitor.py
@@ -252,7 +252,7 @@ def configure_optimizers(self):
limit_train_batches=0.5,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
@@ -273,7 +273,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == ["lr-SGD"]
@@ -311,7 +311,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)

with pytest.raises(
@@ -389,7 +389,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
limit_train_batches=2,
callbacks=[TestFinetuning(), lr_monitor, Check()],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
enable_checkpointing=False,
)
model = TestModel()
9 changes: 9 additions & 0 deletions tests/callbacks/test_model_summary.py
@@ -36,6 +36,15 @@ def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_model_summary_callback_with_weights_summary():

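The body of `test_model_summary_callback_with_weights_summary` is truncated above; the following is a hedged guess (not the repo's actual test) at the kind of check it might perform for the still-supported but deprecated path:

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    def test_weights_summary_full_still_adds_callback():
        # Passing a non-default value warns about the deprecation but is still honored.
        with pytest.deprecated_call(match="weights_summary=full"):
            trainer = Trainer(weights_summary="full")
        assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)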
2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -547,7 +547,7 @@ def _test_progress_bar_max_val_check_interval(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
val_check_interval=val_check_interval,
gpus=world_size,
accelerator="ddp",