protect progress bar callback (#1855)
* wip protected progress bar settings

* remove callback attr from LRfinder

* whitespace

* changelog
awaelchli authored May 25, 2020
1 parent 112dd5c commit 8ca8336
Showing 7 changed files with 21 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+- Removed unintended Trainer argument `progress_bar_callback`, the callback should be passed in by `Trainer(callbacks=[...])` instead ([#1855](https://github.com/PyTorchLightning/pytorch-lightning/pull/1855))
+
 ### Fixed
 
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
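For users, the changelog entry above is the visible change: the `progress_bar_callback` Trainer argument is gone and a custom bar goes through the generic callbacks list instead. A minimal sketch of the new pattern (import paths assumed from the package layout at this point in history, not taken from this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar

    # Previously: Trainer(progress_bar_callback=ProgressBar(refresh_rate=20))
    # Now the progress bar is passed like any other callback:
    trainer = Trainer(callbacks=[ProgressBar(refresh_rate=20)], max_epochs=1)
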
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/callback_config.py
@@ -18,8 +18,6 @@ class TrainerCallbackConfigMixin(ABC):
     weights_save_path: str
     ckpt_path: str
     checkpoint_callback: ModelCheckpoint
-    progress_bar_refresh_rate: int
-    process_position: int
 
     @property
     @abstractmethod
@@ -109,20 +107,22 @@ def configure_early_stopping(self, early_stop_callback):
         self.early_stop_callback = early_stop_callback
         self.enable_early_stop = True
 
-    def configure_progress_bar(self):
+    def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
         if len(progress_bars) > 1:
             raise MisconfigurationException(
                 'You added multiple progress bar callbacks to the Trainer, but currently only one'
                 ' progress bar is supported.'
             )
         elif len(progress_bars) == 1:
-            self.progress_bar_callback = progress_bars[0]
-        elif self.progress_bar_refresh_rate > 0:
-            self.progress_bar_callback = ProgressBar(
-                refresh_rate=self.progress_bar_refresh_rate,
-                process_position=self.process_position,
+            progress_bar_callback = progress_bars[0]
+        elif refresh_rate > 0:
+            progress_bar_callback = ProgressBar(
+                refresh_rate=refresh_rate,
+                process_position=process_position,
             )
-            self.callbacks.append(self.progress_bar_callback)
+            self.callbacks.append(progress_bar_callback)
         else:
-            self.progress_bar_callback = None
+            progress_bar_callback = None
+
+        return progress_bar_callback
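In short, `configure_progress_bar` now returns the resolved callback instead of assigning Trainer attributes: at most one progress bar callback is accepted, a default `ProgressBar` is appended only when the refresh rate is positive, and `None` is returned otherwise. A sketch of how this behaves from the user side (import paths, including the exception location, are assumptions based on the package layout of this era, not part of the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # refresh rate 0 disables the bar: no ProgressBar callback is created
    assert Trainer(progress_bar_refresh_rate=0).progress_bar_callback is None

    # a single user-supplied bar is picked up as-is
    bar = ProgressBar(refresh_rate=5)
    assert Trainer(callbacks=[bar]).progress_bar_callback is bar

    # two progress bars are rejected while the Trainer is being constructed
    try:
        Trainer(callbacks=[ProgressBar(), ProgressBar()])
    except MisconfigurationException as err:
        print(err)
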
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/deprecated_api.py
@@ -121,7 +121,7 @@ def show_progress_bar(self):
         """Back compatibility, will be removed in v0.9.0"""
         rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
                        " and this method will be removed in v0.9.0", DeprecationWarning)
-        return self.progress_bar_refresh_rate >= 1
+        return self.progress_bar_callback and self.progress_bar_callback.refresh_rate >= 1
 
     @show_progress_bar.setter
     def show_progress_bar(self, tf):
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/lr_finder.py
@@ -198,11 +198,9 @@ def __lr_finder_dump_params(self, model):
             'callbacks': self.callbacks,
             'logger': self.logger,
             'max_steps': self.max_steps,
-            'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
             'enable_early_stop': self.enable_early_stop,
-            'progress_bar_callback': self.progress_bar_callback,
             'configure_optimizers': model.configure_optimizers,
         }

@@ -211,11 +209,9 @@ def __lr_finder_restore_params(self, model):
         self.logger = self.__dumped_params['logger']
         self.callbacks = self.__dumped_params['callbacks']
         self.max_steps = self.__dumped_params['max_steps']
-        self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
         self.enable_early_stop = self.__dumped_params['enable_early_stop']
-        self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
         model.configure_optimizers = self.__dumped_params['configure_optimizers']
         del self.__dumped_params

11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -130,7 +130,6 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
@@ -364,7 +363,6 @@ def __init__(
             rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
         self.num_processes = num_processes
 
-        self.process_position = process_position
         self.weights_summary = weights_summary
 
         self.max_epochs = max_epochs
@@ -506,9 +504,7 @@ def __init__(
         if show_progress_bar is not None:
             self.show_progress_bar = show_progress_bar
 
-        self.progress_bar_refresh_rate = progress_bar_refresh_rate
-        self.progress_bar_callback = progress_bar_callback
-        self.configure_progress_bar()
+        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
 
         # logging
         self.log_save_interval = log_save_interval
@@ -661,7 +657,6 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
             'min_steps': None,
             ...
             'profiler': None,
-            'progress_bar_callback': True,
             'progress_bar_refresh_rate': 1,
             ...}
@@ -756,6 +751,10 @@ def num_gpus(self) -> int:
     def data_parallel(self) -> bool:
         return self.use_dp or self.use_ddp or self.use_ddp2
 
+    @property
+    def progress_bar_callback(self):
+        return self._progress_bar_callback
+
     @property
     def progress_bar_dict(self) -> dict:
         """ Read-only for progress bar metrics. """
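With `progress_bar_callback` exposed as a read-only property backed by `_progress_bar_callback`, the callback can still be inspected but no longer rebound from user code. A minimal sketch of the resulting access pattern (argument names taken from the diff, the rest assumed):

    from pytorch_lightning import Trainer

    trainer = Trainer(progress_bar_refresh_rate=20, process_position=1)

    # reading the configured callback still works
    print(trainer.progress_bar_callback.refresh_rate)  # 20

    # the property has no setter, so assignment now fails
    try:
        trainer.progress_bar_callback = None
    except AttributeError as err:
        print(err)
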
2 changes: 1 addition & 1 deletion tests/callbacks/test_progress_bar.py
@@ -179,7 +179,7 @@ def on_test_batch_end(self, trainer, pl_module):
         num_sanity_val_steps=2,
         max_epochs=3,
     )
-    assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate
+    assert trainer.progress_bar_callback.refresh_rate == refresh_rate
 
     trainer.fit(model)
     assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
5 changes: 2 additions & 3 deletions tests/trainer/test_lr_finder.py
@@ -57,9 +57,8 @@ def test_trainer_reset_correctly(tmpdir):
     )
 
     changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
-                          'progress_bar_refresh_rate', 'early_stop_callback',
-                          'accumulate_grad_batches', 'enable_early_stop',
-                          'checkpoint_callback']
+                          'early_stop_callback', 'accumulate_grad_batches',
+                          'enable_early_stop', 'checkpoint_callback']
     attributes_before = {}
     for ca in changed_attributes:
         attributes_before[ca] = getattr(trainer, ca)
