
Add enable_model_summary flag and deprecate weights_summary #9699

Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))


- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


### Changed

- Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))
@@ -344,6 +347,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Deprecated passing `weights_summary` to the `Trainer` constructor in favor of adding the `ModelSummary` callback with `max_depth` directly to the list of callbacks ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
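
The deprecation entry above translates to roughly the following before/after usage (a hedged sketch; the replacement API is the one introduced in this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

# Before (deprecated in v1.5, slated for removal in v1.7):
trainer = Trainer(weights_summary="full")  # full-depth summary
trainer = Trainer(weights_summary=None)    # no summary

# After:
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # full-depth summary
trainer = Trainer(enable_model_summary=False)              # no summary
```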


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
2 changes: 1 addition & 1 deletion benchmarks/test_basic_parity.py
@@ -159,7 +159,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
# as the first run is skipped, no need to run it long
max_epochs=num_epochs if idx > 0 else 1,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
gpus=1 if device_type == "cuda" else 0,
checkpoint_callback=False,
logger=False,
12 changes: 8 additions & 4 deletions docs/source/common/debugging.rst
@@ -95,11 +95,14 @@ Print a summary of your LightningModule
---------------------------------------
Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
By default it only prints the top-level modules. If you want to show all submodules in your network, use the
`'full'` option:
``max_depth`` option:

.. testcode::

trainer = Trainer(weights_summary="full")
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])


You can also display the intermediate input- and output sizes of all your layers by setting the
``example_input_array`` attribute in your LightningModule. It will print a table like this
@@ -115,8 +118,9 @@ You can also display the intermediate input- and output sizes of all your layers
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
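
For reference, a minimal sketch of the attribute mentioned above (the `LitModel` below is hypothetical and not part of this diff):

```python
import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2)
        )
        # a dummy input; its shape is traced through forward() to fill the
        # per-layer input/output size columns of the printed summary
        self.example_input_array = torch.zeros(1, 32)

    def forward(self, x):
        return self.net(x)
```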

See Also:
- :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
- :class:`~pytorch_lightning.core.memory.ModelSummary`
- :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
- :func:`~pytorch_lightning.utilities.model_summary.summarize`
- :class:`~pytorch_lightning.utilities.model_summary.ModelSummary`

----------------

24 changes: 24 additions & 0 deletions docs/source/common/trainer.rst
@@ -1589,6 +1589,11 @@ Example::
weights_summary
^^^^^^^^^^^^^^^

.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument instead. To disable the model summary,
pass ``enable_model_summary = False`` to the Trainer.


.. raw:: html

<video width="50%" max-width="400px" controls
@@ -1611,6 +1616,25 @@ Options: 'full', 'top', None.
# don't print a summary
trainer = Trainer(weights_summary=None)


enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the model summarization. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])

-----

Trainer class API
2 changes: 1 addition & 1 deletion pl_examples/bug_report_model.py
@@ -55,7 +55,7 @@ def run():
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
trainer.test(model, dataloaders=test_data)
@@ -272,7 +272,7 @@ def add_arguments_to_parser(self, parser):
parser.set_defaults(
{
"trainer.max_epochs": 15,
"trainer.weights_summary": None,
"trainer.enable_model_summary": False,
"trainer.num_sanity_val_steps": 0,
}
)
51 changes: 39 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -45,6 +45,7 @@ def on_trainer_init(
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
enable_model_summary: bool,
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
@@ -101,7 +102,7 @@ def on_trainer_init(
self.trainer._progress_bar_callback = None

# configure the ModelSummary callback
self._configure_model_summary_callback(weights_summary)
self._configure_model_summary_callback(enable_model_summary, weights_summary)

# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)
@@ -159,24 +160,50 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e
if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
def _configure_model_summary_callback(
self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
if weights_summary is None:
rank_zero_deprecation(
"Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
" in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
)
return
if not enable_model_summary:
return

model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
if model_summary_cbs:
rank_zero_info(
f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
" Skipping setting a default `ModelSummary` callback."
)
return
if weights_summary is not None:

if weights_summary == "top":
# special case the default value for weights_summary to preserve backward compatibility
max_depth = 1
else:
rank_zero_deprecation(
f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
" `max_depth` directly to the Trainer's `callbacks` argument instead."
)
if weights_summary not in ModelSummaryMode.supported_types():
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
f" but got {weights_summary}",
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)
if self.trainer._progress_bar_callback is not None and isinstance(
self.trainer._progress_bar_callback, RichProgressBar
):
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer.weights_summary = weights_summary

is_progress_bar_rich = isinstance(self.trainer._progress_bar_callback, RichProgressBar)

if self.trainer._progress_bar_callback is not None and is_progress_bar_rich:
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer._weights_summary = weights_summary

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
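
A hedged usage sketch of the precedence implemented above, inferred from this diff: `weights_summary=None` and `enable_model_summary=False` both suppress the default summary, an explicit `ModelSummary` in `callbacks` takes precedence over the default, and non-default `weights_summary` strings still work but emit a deprecation warning:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

# Default: weights_summary="top" is special-cased to ModelSummary(max_depth=1)
# and raises no deprecation warning.
trainer = Trainer()

# Disabled: neither call appends a summary callback; the weights_summary=None
# spelling additionally emits a deprecation warning.
trainer = Trainer(enable_model_summary=False)
trainer = Trainer(weights_summary=None)

# A user-provided summary callback suppresses the default one; the connector
# only logs that a ModelSummary callback is already configured.
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])

# Other string values are mapped through ModelSummaryMode (e.g. "full" -> max_depth=-1)
# and emit a deprecation warning pointing to the ModelSummary callback.
trainer = Trainer(weights_summary="full")
```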
23 changes: 23 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -157,6 +157,7 @@ def __init__(
accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
@@ -373,8 +374,16 @@ def __init__(
val_check_interval: How often to check the validation set. Use float to check within a training epoch,
use int to check every n steps (batches).

enable_model_summary: Whether to enable model summarization by default.

weights_summary: Prints a summary of the weights when training begins.

.. deprecated:: v1.5
``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument.

weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
@@ -467,6 +476,9 @@ def __init__(
self.tested_ckpt_path: Optional[str] = None
self.predicted_ckpt_path: Optional[str] = None

# todo: remove in v1.7
self._weights_summary: Optional[str] = None

# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
@@ -478,6 +490,7 @@ def __init__(
process_position,
default_root_dir,
weights_save_path,
enable_model_summary,
weights_summary,
stochastic_weight_avg,
max_time,
@@ -2036,6 +2049,16 @@ def _exit_gracefully_on_signal(self) -> None:
class_name = caller[0].f_locals["self"].__class__.__name__
raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")

@property
def weights_summary(self) -> Optional[str]:
rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
return self._weights_summary

@weights_summary.setter
def weights_summary(self, val: Optional[str]) -> None:
rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
self._weights_summary = val

"""
Other
"""
4 changes: 2 additions & 2 deletions tests/accelerators/test_multi_nodes_gpu.py
@@ -54,7 +54,7 @@ def validation_step(self, batch, batch_idx):
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
@@ -101,7 +101,7 @@ def backward(self, loss, optimizer, optimizer_idx):
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
4 changes: 2 additions & 2 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -53,7 +53,7 @@ def training_epoch_end(self, outputs) -> None:
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
)

assert any(isinstance(c, CB) for c in trainer.callbacks)
@@ -74,7 +74,7 @@ def on_epoch_end(self, trainer, pl_module):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)

trainer.fit(model)
6 changes: 3 additions & 3 deletions tests/callbacks/test_gradient_accumulation_scheduler.py
@@ -31,7 +31,7 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
limit_train_batches=20,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
accumulate_grad_batches=accumulate_grad_batches,
)
assert trainer.accumulate_grad_batches == accumulate_grad_batches
@@ -56,7 +56,7 @@ def test_trainer_accumulate_grad_batches_dict_zero_grad(tmpdir, accumulate_grad_
limit_train_batches=10,
limit_val_batches=1,
max_epochs=4,
weights_summary=None,
enable_model_summary=False,
accumulate_grad_batches=accumulate_grad_batches,
)
assert trainer.accumulate_grad_batches == accumulate_grad_batches.get(0, 1)
@@ -74,7 +74,7 @@ def test_trainer_accumulate_grad_batches_with_callback(tmpdir):
limit_train_batches=10,
limit_val_batches=1,
max_epochs=4,
weights_summary=None,
enable_model_summary=False,
callbacks=[GradientAccumulationScheduler({1: 2, 3: 4})],
)
assert trainer.accumulate_grad_batches == 1
8 changes: 4 additions & 4 deletions tests/callbacks/test_lr_monitor.py
@@ -252,7 +252,7 @@ def configure_optimizers(self):
limit_train_batches=0.5,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
@@ -273,7 +273,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == ["lr-SGD"]
@@ -311,7 +311,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)

with pytest.raises(
@@ -389,7 +389,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
limit_train_batches=2,
callbacks=[TestFinetuning(), lr_monitor, Check()],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
enable_checkpointing=False,
)
model = TestModel()
9 changes: 9 additions & 0 deletions tests/callbacks/test_model_summary.py
@@ -36,6 +36,15 @@ def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
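
A hedged sketch of the complementary positive case (not part of this diff): with summaries enabled and no explicit callback, the connector should have appended a default `ModelSummary`.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

def test_model_summary_callback_present_by_default():
    # enable_model_summary=True (the default) appends a ModelSummary callback
    trainer = Trainer(enable_model_summary=True)
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
```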


def test_model_summary_callback_with_weights_summary():

2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -547,7 +547,7 @@ def _test_progress_bar_max_val_check_interval(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
val_check_interval=val_check_interval,
gpus=world_size,
accelerator="ddp",