From ee63840e8b2449f651a96fdcb0468f8a7283c576 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 13 Oct 2021 04:50:54 -0700 Subject: [PATCH] Add `enable_model_summary` flag and deprecate `weights_summary` (#9699) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kaushik B --- CHANGELOG.md | 6 +++ benchmarks/test_basic_parity.py | 2 +- docs/source/common/debugging.rst | 12 +++-- docs/source/common/trainer.rst | 24 +++++++++ pl_examples/bug_report_model.py | 2 +- .../computer_vision_fine_tuning.py | 2 +- .../trainer/connectors/callback_connector.py | 51 ++++++++++++++----- pytorch_lightning/trainer/trainer.py | 23 +++++++++ tests/accelerators/test_multi_nodes_gpu.py | 4 +- tests/callbacks/test_callback_hook_outputs.py | 4 +- .../test_gradient_accumulation_scheduler.py | 6 +-- tests/callbacks/test_lr_monitor.py | 8 +-- tests/callbacks/test_model_summary.py | 9 ++++ tests/callbacks/test_progress_bar.py | 2 +- tests/callbacks/test_pruning.py | 6 +-- .../test_checkpoint_callback_frequency.py | 6 +-- tests/checkpointing/test_model_checkpoint.py | 18 +++---- tests/core/test_datamodules.py | 8 +-- tests/core/test_lightning_optimizer.py | 12 ++--- tests/deprecated_api/test_remove_1-7.py | 15 ++++++ tests/loggers/test_base.py | 6 +-- tests/loops/batch/test_truncated_bptt.py | 6 +-- tests/loops/test_evaluation_loop.py | 8 +-- tests/loops/test_evaluation_loop_flow.py | 8 +-- tests/loops/test_flow_warnings.py | 2 +- tests/loops/test_loops.py | 2 +- tests/loops/test_training_loop.py | 2 +- tests/loops/test_training_loop_flow_dict.py | 8 +-- tests/loops/test_training_loop_flow_scalar.py | 12 ++--- tests/models/test_cpu.py | 1 - tests/models/test_hooks.py | 12 ++--- tests/models/test_horovod.py | 2 +- .../connectors/test_callback_connector.py | 5 +- .../test_multiple_eval_dataloaders.py | 6 +-- tests/trainer/flags/test_min_max_epochs.py | 2 +- tests/trainer/flags/test_overfit_batches.py | 4 +- .../logging_/test_distributed_logging.py | 4 +- .../logging_/test_eval_loop_logging.py | 16 +++--- .../logging_/test_train_loop_logging.py | 18 +++---- .../optimization/test_manual_optimization.py | 12 ++--- .../optimization/test_multiple_optimizers.py | 8 +-- tests/trainer/optimization/test_optimizers.py | 2 +- tests/trainer/test_dataloaders.py | 4 +- tests/trainer/test_trainer.py | 6 +-- tests/utilities/test_auto_restart.py | 2 +- tests/utilities/test_cli.py | 4 +- 46 files changed, 248 insertions(+), 134 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bc6c42841659..081faa325ab4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848)) +- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699)) + + ### Changed - Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867)) @@ -344,6 +347,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525)) +- Deprecated passing `weights_summary` to the `Trainer` constructor in favor of adding the `ModelSummary` callback with `max_depth` directly to the list of callbacks ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py index e9442dd26e65b..2144be39394cb 100644 --- a/benchmarks/test_basic_parity.py +++ b/benchmarks/test_basic_parity.py @@ -159,7 +159,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10): # as the first run is skipped, no need to run it long max_epochs=num_epochs if idx > 0 else 1, enable_progress_bar=False, - weights_summary=None, + enable_model_summary=False, gpus=1 if device_type == "cuda" else 0, checkpoint_callback=False, logger=False, diff --git a/docs/source/common/debugging.rst b/docs/source/common/debugging.rst index 7a11863c0e1bf..6e5a721dd092a 100644 --- a/docs/source/common/debugging.rst +++ b/docs/source/common/debugging.rst @@ -95,11 +95,14 @@ Print a summary of your LightningModule --------------------------------------- Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule. By default it only prints the top-level modules. If you want to show all submodules in your network, use the -`'full'` option: +``max_depth`` option: .. testcode:: - trainer = Trainer(weights_summary="full") + from pytorch_lightning.callbacks import ModelSummary + + trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)]) + You can also display the intermediate input- and output sizes of all your layers by setting the ``example_input_array`` attribute in your LightningModule. It will print a table like this @@ -115,8 +118,9 @@ You can also display the intermediate input- and output sizes of all your layers when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers. See Also: - - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument - - :class:`~pytorch_lightning.core.memory.ModelSummary` + - :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary` + - :func:`~pytorch_lightning.utilities.model_summary.summarize` + - :class:`~pytorch_lightning.utilities.model_summary.ModelSummary` ---------------- diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 51726608ed90d..24f08b2f95091 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1589,6 +1589,11 @@ Example:: weights_summary ^^^^^^^^^^^^^^^ +.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary` + directly to the Trainer's ``callbacks`` argument instead. To disable the model summary, + pass ``enable_model_summary = False`` to the Trainer. + + .. raw:: html