protect progress bar callback #1855

Merged 4 commits on May 25, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

+- Removed unintended Trainer argument `progress_bar_callback`; the callback should be passed in via `Trainer(callbacks=[...])` instead ([#1855](https://github.com/PyTorchLightning/pytorch-lightning/pull/1855))
+
 ### Fixed

 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
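For context, a minimal sketch of the usage this changelog entry describes. The argument values are illustrative, not taken from the PR:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

# With `progress_bar_callback` removed from Trainer's signature, a custom
# progress bar is now passed like any other callback.
bar = ProgressBar(refresh_rate=20, process_position=0)
trainer = Trainer(callbacks=[bar])
```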
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/callback_config.py
@@ -18,8 +18,6 @@ class TrainerCallbackConfigMixin(ABC):
     weights_save_path: str
     ckpt_path: str
     checkpoint_callback: ModelCheckpoint
-    progress_bar_refresh_rate: int
-    process_position: int

     @property
     @abstractmethod
@@ -109,20 +107,22 @@ def configure_early_stopping(self, early_stop_callback):
         self.early_stop_callback = early_stop_callback
         self.enable_early_stop = True

-    def configure_progress_bar(self):
+    def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
         if len(progress_bars) > 1:
             raise MisconfigurationException(
                 'You added multiple progress bar callbacks to the Trainer, but currently only one'
                 ' progress bar is supported.'
             )
         elif len(progress_bars) == 1:
-            self.progress_bar_callback = progress_bars[0]
-        elif self.progress_bar_refresh_rate > 0:
-            self.progress_bar_callback = ProgressBar(
-                refresh_rate=self.progress_bar_refresh_rate,
-                process_position=self.process_position,
+            progress_bar_callback = progress_bars[0]
+        elif refresh_rate > 0:
+            progress_bar_callback = ProgressBar(
+                refresh_rate=refresh_rate,
+                process_position=process_position,
             )
-            self.callbacks.append(self.progress_bar_callback)
+            self.callbacks.append(progress_bar_callback)
         else:
-            self.progress_bar_callback = None
+            progress_bar_callback = None
+
+        return progress_bar_callback

Review thread on the `MisconfigurationException` lines:

Member: I guess we can allow multiple in the future; just check that each one is different, meaning a different monitored event, frequency, etc.

Contributor Author: Yes, agreed, that would be good. The only reason we currently have a limit of one is that the Trainer needs to be able to disable the progress bar temporarily, for example in the LRFinder.
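A short sketch of the behavior the refactored method gives the Trainer, following the branches above (a refresh rate of 0 disables the bar; a user-supplied bar wins over the default):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

# refresh_rate=0 takes the final `else` branch: no bar is created.
trainer = Trainer(progress_bar_refresh_rate=0)
assert trainer.progress_bar_callback is None

# A user-supplied bar takes the `len(progress_bars) == 1` branch and is
# returned as-is rather than being wrapped or replaced.
bar = ProgressBar(refresh_rate=10)
trainer = Trainer(callbacks=[bar])
assert trainer.progress_bar_callback is bar
```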
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/deprecated_api.py
@@ -121,7 +121,7 @@ def show_progress_bar(self):
         """Back compatibility, will be removed in v0.9.0"""
         rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
                        " and this method will be removed in v0.9.0", DeprecationWarning)
-        return self.progress_bar_refresh_rate >= 1
+        return self.progress_bar_callback and self.progress_bar_callback.refresh_rate >= 1

     @show_progress_bar.setter
     def show_progress_bar(self, tf):
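A sketch of what the back-compat property now returns. Note the short-circuit: with no bar configured, the expression evaluates to None, which is falsy (accessing the property also emits the DeprecationWarning shown above):

```python
from pytorch_lightning import Trainer

trainer = Trainer(progress_bar_refresh_rate=0)
assert not trainer.show_progress_bar  # progress_bar_callback is None -> falsy

trainer = Trainer(progress_bar_refresh_rate=5)
assert trainer.show_progress_bar      # callback exists and refresh_rate >= 1
```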
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/lr_finder.py
@@ -198,11 +198,9 @@ def __lr_finder_dump_params(self, model):
             'callbacks': self.callbacks,
             'logger': self.logger,
             'max_steps': self.max_steps,
-            'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
             'enable_early_stop': self.enable_early_stop,
-            'progress_bar_callback': self.progress_bar_callback,
             'configure_optimizers': model.configure_optimizers,
         }

@@ -211,11 +209,9 @@ def __lr_finder_restore_params(self, model):
         self.logger = self.__dumped_params['logger']
         self.callbacks = self.__dumped_params['callbacks']
         self.max_steps = self.__dumped_params['max_steps']
-        self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
         self.enable_early_stop = self.__dumped_params['enable_early_stop']
-        self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
         model.configure_optimizers = self.__dumped_params['configure_optimizers']
         del self.__dumped_params

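These deletions work because the progress bar now travels inside `self.callbacks`, which the LR finder already dumps and restores. A minimal, self-contained sketch of that stash-and-restore pattern (class and attribute names here are illustrative only):

```python
class DumpRestoreDemo:
    """Illustrates the pattern behind __lr_finder_dump_params/__lr_finder_restore_params."""

    def __init__(self):
        self.callbacks = ['progress_bar']  # the bar rides along in here
        self.max_steps = None

    def dump_params(self):
        # Stash every attribute the LR search will mutate.
        self._dumped = {'callbacks': self.callbacks, 'max_steps': self.max_steps}

    def restore_params(self):
        # Put the stashed values back and drop the stash.
        self.callbacks = self._dumped['callbacks']
        self.max_steps = self._dumped['max_steps']
        del self._dumped
```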
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -130,7 +130,6 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
@@ -364,7 +363,6 @@ def __init__(
             rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
         self.num_processes = num_processes

-        self.process_position = process_position
         self.weights_summary = weights_summary

         self.max_epochs = max_epochs
@@ -505,9 +503,7 @@ def __init__(
         if show_progress_bar is not None:
             self.show_progress_bar = show_progress_bar

-        self.progress_bar_refresh_rate = progress_bar_refresh_rate
-        self.progress_bar_callback = progress_bar_callback
-        self.configure_progress_bar()
+        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

         # logging
         self.log_save_interval = log_save_interval
@@ -660,7 +656,6 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
             'min_steps': None,
             ...
             'profiler': None,
-            'progress_bar_callback': True,
             'progress_bar_refresh_rate': 1,
             ...}
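A sketch of the CLI path this docstring documents, with `progress_bar_callback` gone from the generated defaults (the flag value below is arbitrary):

```python
from argparse import ArgumentParser
from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)

# `--progress_bar_callback` no longer exists; the refresh rate flag remains.
args = parser.parse_args(['--progress_bar_refresh_rate', '25'])
trainer = Trainer(**vars(args))
```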

@@ -755,6 +750,10 @@ def num_gpus(self) -> int:
     def data_parallel(self) -> bool:
         return self.use_dp or self.use_ddp or self.use_ddp2

+    @property
+    def progress_bar_callback(self):
+        return self._progress_bar_callback
+
     @property
     def progress_bar_dict(self) -> dict:
         """ Read-only for progress bar metrics. """

Review thread on the new property:

Member: Not sure how much this protects; it returns a pointer to the same object, so an edit to the returned object will show up in the original one...

Contributor Author: Yes, I think that's fine. The goal was to make the reference read-only so that you can't rebind it to anything other than a progress bar.
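A sketch of the protection discussed in the thread above: the property has no setter, so rebinding fails with an AttributeError, while the object itself (as the Member notes) remains mutable:

```python
from pytorch_lightning import Trainer

trainer = Trainer()
trainer.progress_bar_callback.disable()   # mutating the object is still possible

try:
    trainer.progress_bar_callback = None  # but rebinding the reference fails
except AttributeError:
    print("progress_bar_callback is read-only")
```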
2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -179,7 +179,7 @@ def on_test_batch_end(self, trainer, pl_module):
         num_sanity_val_steps=2,
         max_epochs=3,
     )
-    assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate
+    assert trainer.progress_bar_callback.refresh_rate == refresh_rate

    trainer.fit(model)
    assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
5 changes: 2 additions & 3 deletions tests/trainer/test_lr_finder.py
@@ -57,9 +57,8 @@ def test_trainer_reset_correctly(tmpdir):
     )

     changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
-                          'progress_bar_refresh_rate', 'early_stop_callback',
-                          'accumulate_grad_batches', 'enable_early_stop',
-                          'checkpoint_callback']
+                          'early_stop_callback', 'accumulate_grad_batches',
+                          'enable_early_stop', 'checkpoint_callback']
     attributes_before = {}
     for ca in changed_attributes:
         attributes_before[ca] = getattr(trainer, ca)