Commit

update
ananthsub committed Sep 30, 2021
1 parent 2cec4b0 commit d83ed9d
Showing 2 changed files with 20 additions and 15 deletions.
docs/source/common/debugging.rst (4 changes: 1 addition & 3 deletions)

@@ -98,9 +98,7 @@ By default it only prints the top-level modules. If you want to show all submodu
 `'max_depth'` option:
 
 .. testcode::
-    from pytorch_lightning.callbacks import ModelSummary
-
-    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])
+    trainer = Trainer(weights_summary="full")
 
 You can also display the intermediate input- and output sizes of all your layers by setting the
 ``example_input_array`` attribute in your LightningModule. It will print a table like this
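For reference, a minimal sketch of the two spellings this docs hunk trades between, assuming pytorch_lightning v1.5-era imports (the `weights_summary` argument is deprecated by the connector change below):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # Callback form: max_depth=-1 summarizes every submodule, max_depth=1 only top-level modules.
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])

    # Trainer-argument form shown in the updated docs; deprecated in v1.5 per the change below.
    trainer = Trainer(weights_summary="full")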
pytorch_lightning/trainer/connectors/callback_connector.py (31 changes: 19 additions & 12 deletions)

@@ -118,9 +118,9 @@ def _configure_accumulated_gradients(
         if grad_accum_callback:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
-                    "You have set both `accumulate_grad_batches` and passed an instance of "
-                    "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
-                    "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
+                    "You have set both `accumulate_grad_batches` and passed an instance of"
+                    " `GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches`"
+                    " from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                 )
             grad_accum_callback = grad_accum_callback[0]
         else:
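This hunk only re-wraps the error message, moving the separating spaces to the start of each string fragment; the guard itself enforces that gradient accumulation is configured exactly one way. A minimal sketch of the two mutually exclusive options, for illustration:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # Either pass a fixed accumulation factor to the Trainer ...
    trainer = Trainer(accumulate_grad_batches=4)

    # ... or schedule it per epoch via the callback. Doing both at once
    # raises the MisconfigurationException shown above.
    trainer = Trainer(callbacks=[GradientAccumulationScheduler(scheduling={0: 4})])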
@@ -162,19 +162,27 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
     def _configure_model_summary_callback(
         self, enable_model_summary: bool, weights_summary: Optional[str] = None
     ) -> None:
-        if not enable_model_summary:
-            return
-        if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
-            return
         if weights_summary is None:
             rank_zero_deprecation(
                 "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
                 " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
             )
             return
-        # Prior default in the Trainer for `weights_summary` which we explicitly check here
-        # to preserve backwards compatibility
-        if weights_summary != "top":
+        if not enable_model_summary:
+            return
+
+        model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
+        if model_summary_cbs:
+            rank_zero_info(
+                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
+                " Skipping setting a default `ModelSummary` callback."
+            )
+            return
+
+        if weights_summary == "top":
+            # special case the default value for weights_summary to preserve backward compatibility
+            max_depth = 1
+        else:
             rank_zero_deprecation(
                 f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
                 " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
@@ -186,8 +194,7 @@ def _configure_model_summary_callback(
                     f" but got {weights_summary}",
                 )
             max_depth = ModelSummaryMode.get_max_depth(weights_summary)
-        else:
-            max_depth = 1
+
         if self.trainer._progress_bar_callback is not None and isinstance(
             self.trainer._progress_bar_callback, RichProgressBar
         ):
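Net effect of the reordered checks, as a minimal sketch (the `weights_summary`-to-`max_depth` mapping follows `ModelSummaryMode`, e.g. "full" maps to -1):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # A user-supplied summary callback now takes precedence: the connector logs an
    # informational message and skips installing its default ModelSummary callback.
    trainer = Trainer(callbacks=[ModelSummary(max_depth=2)])

    # Deprecated path: a non-default weights_summary still works in v1.5 but emits a
    # deprecation warning and is translated to the equivalent max_depth.
    trainer = Trainer(weights_summary="full")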
