Add enable_model_summary flag and deprecate weights_summary #9699

Merged

Changes from 13 commits
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -308,12 +311,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


- Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Deprecated passing `weights_summary` to the `Trainer` constructor in favor of adding the `ModelSummary` callback with `max_depth` directly to the list of callbacks ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
2 changes: 1 addition & 1 deletion benchmarks/test_basic_parity.py
@@ -159,7 +159,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
# as the first run is skipped, no need to run it long
max_epochs=num_epochs if idx > 0 else 1,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
gpus=1 if device_type == "cuda" else 0,
checkpoint_callback=False,
logger=False,
12 changes: 8 additions & 4 deletions docs/source/common/debugging.rst
@@ -95,11 +95,14 @@ Print a summary of your LightningModule
---------------------------------------
Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
By default it only prints the top-level modules. If you want to show all submodules in your network, use the
`'full'` option:
``max_depth`` option:

.. testcode::

trainer = Trainer(weights_summary="full")
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])


You can also display the intermediate input- and output sizes of all your layers by setting the
``example_input_array`` attribute in your LightningModule. It will print a table like this
@@ -115,8 +118,9 @@ You can also display the intermediate input- and output sizes of all your layers
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
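
A rough sketch of how that attribute can be set (the module names and layer sizes below are illustrative, not taken from this PR):

.. testcode::

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 3))
            # with this attribute set, the summary also reports per-layer input/output sizes
            self.example_input_array = torch.rand(1, 32)

        def forward(self, x):
            return self.net(x)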

See Also:
- :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
- :class:`~pytorch_lightning.core.memory.ModelSummary`
- :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
- :func:`~pytorch_lightning.utilities.model_summary.summarize`
- :class:`~pytorch_lightning.utilities.model_summary.ModelSummary`

----------------

24 changes: 24 additions & 0 deletions docs/source/common/trainer.rst
@@ -1589,6 +1589,11 @@ Example::
weights_summary
^^^^^^^^^^^^^^^

.. warning:: ``weights_summary`` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
    directly to the Trainer's ``callbacks`` argument instead. To disable the model summary,
    pass ``enable_model_summary=False`` to the Trainer.


.. raw:: html

<video width="50%" max-width="400px" controls
@@ -1611,6 +1616,25 @@ Options: 'full', 'top', None.
# don't print a summary
trainer = Trainer(weights_summary=None)
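
A minimal migration sketch, following the deprecation message above (``weights_summary="full"`` corresponds to ``max_depth=-1``):

.. testcode::

    from pytorch_lightning.callbacks import ModelSummary

    # before (deprecated in v1.5)
    trainer = Trainer(weights_summary="full")

    # after: the same summary via the callback
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])

    # before (deprecated in v1.5)
    trainer = Trainer(weights_summary=None)

    # after: disable the summary entirely
    trainer = Trainer(enable_model_summary=False)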


enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the model summarization. Defaults to ``True``.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
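
Note that, per the callback connector changes in this PR, when a ``RichProgressBar`` is among the callbacks the default summary is rendered with ``RichModelSummary`` instead of the plain ``ModelSummary``. A small sketch (assuming ``RichProgressBar`` is available, i.e. ``rich`` is installed):

.. code-block:: python

    from pytorch_lightning.callbacks import RichProgressBar

    # the default summary callback selected here is a RichModelSummary
    trainer = Trainer(enable_model_summary=True, callbacks=[RichProgressBar()])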

-----

Trainer class API
2 changes: 1 addition & 1 deletion pl_examples/bug_report_model.py
@@ -55,7 +55,7 @@ def run():
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
trainer.test(model, dataloaders=test_data)
@@ -272,7 +272,7 @@ def add_arguments_to_parser(self, parser):
parser.set_defaults(
{
"trainer.max_epochs": 15,
"trainer.weights_summary": None,
"trainer.enable_model_summary": False,
"trainer.num_sanity_val_steps": 0,
}
)
57 changes: 42 additions & 15 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -44,6 +44,7 @@ def on_trainer_init(
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
enable_model_summary: bool,
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
@@ -100,7 +101,7 @@ def on_trainer_init(
self.trainer._progress_bar_callback = None

# configure the ModelSummary callback
self._configure_model_summary_callback(weights_summary)
self._configure_model_summary_callback(enable_model_summary, weights_summary)

# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)
@@ -117,9 +118,9 @@ def _configure_accumulated_gradients(
if grad_accum_callback:
if accumulate_grad_batches is not None:
raise MisconfigurationException(
"You have set both `accumulate_grad_batches` and passed an instance of "
"`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
"from trainer or remove `GradientAccumulationScheduler` from callbacks list."
"You have set both `accumulate_grad_batches` and passed an instance of"
" `GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches`"
" from trainer or remove `GradientAccumulationScheduler` from callbacks list."
)
grad_accum_callback = grad_accum_callback[0]
else:
@@ -158,24 +159,50 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
def _configure_model_summary_callback(
self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
if weights_summary is None:
rank_zero_deprecation(
"Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
" in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
)
return
if not enable_model_summary:
return

model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
if model_summary_cbs:
rank_zero_info(
f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
" Skipping setting a default `ModelSummary` callback."
)
return
if weights_summary is not None:

if weights_summary == "top":
# special case the default value for weights_summary to preserve backward compatibility
max_depth = 1
else:
rank_zero_deprecation(
f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
" `max_depth` directly to the Trainer's `callbacks` argument instead."
)
if weights_summary not in ModelSummaryMode.supported_types():
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())},"
f" but got {weights_summary}"
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)
if self.trainer._progress_bar_callback is not None and isinstance(
self.trainer._progress_bar_callback, RichProgressBar
):
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer.weights_summary = weights_summary

if self.trainer._progress_bar_callback is not None and isinstance(
self.trainer._progress_bar_callback, RichProgressBar
):
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer._weights_summary = weights_summary

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
23 changes: 23 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -156,6 +156,7 @@ def __init__(
accelerator: Optional[Union[str, Accelerator]] = None,
sync_batchnorm: bool = False,
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
weights_summary: Optional[str] = "top",
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
@@ -357,8 +358,16 @@ def __init__(
val_check_interval: How often to check the validation set. Use float to check within a training epoch,
use int to check every n steps (batches).

enable_model_summary: Whether to enable model summarization by default.

weights_summary: Prints a summary of the weights when training begins.

.. deprecated:: v1.5
``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
To disable the summary, pass ``enable_model_summary=False`` to the Trainer.
To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
directly to the Trainer's ``callbacks`` argument.

weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
@@ -451,6 +460,9 @@ def __init__(
self.tested_ckpt_path: Optional[str] = None
self.predicted_ckpt_path: Optional[str] = None

# todo: remove in v1.7
self._weights_summary: Optional[str] = None

# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
@@ -461,6 +473,7 @@
process_position,
default_root_dir,
weights_save_path,
enable_model_summary,
weights_summary,
stochastic_weight_avg,
max_time,
@@ -2016,6 +2029,16 @@ def _exit_gracefully_on_signal(self) -> None:
class_name = caller[0].f_locals["self"].__class__.__name__
raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")

@property
def weights_summary(self) -> Optional[str]:
rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
return self._weights_summary

@weights_summary.setter
def weights_summary(self, val: Optional[str]) -> None:
rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
self._weights_summary = val

"""
Other
"""
4 changes: 2 additions & 2 deletions tests/accelerators/test_multi_nodes_gpu.py
@@ -54,7 +54,7 @@ def validation_step(self, batch, batch_idx):
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
@@ -101,7 +101,7 @@ def backward(self, loss, optimizer, optimizer_idx):
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
accelerator="ddp",
gpus=1,
num_nodes=2,
4 changes: 2 additions & 2 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -53,7 +53,7 @@ def training_epoch_end(self, outputs) -> None:
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
enable_model_summary=False,
)

assert any(isinstance(c, CB) for c in trainer.callbacks)
@@ -74,7 +74,7 @@ def on_epoch_end(self, trainer, pl_module):
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
)

trainer.fit(model)
8 changes: 4 additions & 4 deletions tests/callbacks/test_lr_monitor.py
@@ -252,7 +252,7 @@ def configure_optimizers(self):
limit_train_batches=0.5,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
@@ -273,7 +273,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == ["lr-SGD"]
@@ -311,7 +311,7 @@ def configure_optimizers(self):
limit_train_batches=2,
callbacks=[lr_monitor],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
)

with pytest.raises(
@@ -389,7 +389,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
limit_train_batches=2,
callbacks=[TestFinetuning(), lr_monitor, Check()],
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
checkpoint_callback=False,
)
model = TestModel()
9 changes: 9 additions & 0 deletions tests/callbacks/test_model_summary.py
@@ -36,6 +36,15 @@ def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=False, weights_summary="full")
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

trainer = Trainer(enable_model_summary=True, weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_model_summary_callback_with_weights_summary():

2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -547,7 +547,7 @@ def _test_progress_bar_max_val_check_interval(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
enable_model_summary=False,
val_check_interval=val_check_interval,
gpus=world_size,
accelerator="ddp",
6 changes: 3 additions & 3 deletions tests/callbacks/test_pruning.py
@@ -108,7 +108,7 @@ def train_with_pruning_callback(
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
checkpoint_callback=False,
logger=False,
limit_train_batches=10,
@@ -226,7 +226,7 @@ def apply_lottery_ticket_hypothesis(self):
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
checkpoint_callback=False,
logger=False,
limit_train_batches=10,
@@ -253,7 +253,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
enable_model_summary=False,
checkpoint_callback=False,
logger=False,
limit_train_batches=10,