Feature/5275 clean progress bar print (#5470)
* Trainer.test should return only test metrics (#5214)

* resolve bug

* merge tests

* Fix metric state reset (#5273)

* Fix metric state reset

* Fix test

* Improve formatting

Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>

* print() method added to ProgressBar

* printing alongside progress bar added to LightningModule.print()

* LightningModule.print() method documentation updated

* ProgressBarBase.print() stub added

* stub

* add progress bar tests

* fix isort

* Progress Callback fixes

* test_metric.py duplicate DummyList removed

* PEP and isort fixes

* CHANGELOG updated

* test_progress_bar_print win linesep fix

* test_progress_bar.py remove whitespaces

* Update CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Tadej Svetina <tadej.svetina@gmail.com>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
Co-authored-by: Alexander Snorkin <Alexander.Snorkin@acronis.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
8 people authored Feb 22, 2021
1 parent 57215b7 commit 423ecf9
Showing 4 changed files with 102 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -89,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
[#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))

- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
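In user code, the new entry means `LightningModule.print()` can be called freely during training without tearing the tqdm bar. A minimal usage sketch (the model class and message below are illustrative, not part of this diff):

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def training_step(self, batch, batch_idx):
        # Routed through the active progress bar's print() when the bar is
        # enabled, so the tqdm line is redrawn cleanly instead of broken up;
        # falls back to the built-in print() otherwise.
        self.print(f"processing batch {batch_idx}")
        ...
```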
24 changes: 24 additions & 0 deletions pytorch_lightning/callbacks/progress.py
@@ -19,6 +19,8 @@
"""
import importlib
import io
import os
import sys

# check if ipywidgets is installed before importing tqdm.auto
@@ -187,6 +189,12 @@ def enable(self):
"""
raise NotImplementedError

def print(self, *args, **kwargs):
"""
You should provide a way to print without breaking the progress bar.
"""
print(*args, **kwargs)

def on_init_end(self, trainer):
self._trainer = trainer
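The base-class `print()` above deliberately falls back to the built-in `print`, so custom progress bars keep working even if they never override it; bars that actually render to the terminal are expected to override it with something bar-aware. A minimal sketch of such an override (the subclass name and prefix tag are illustrative):

```python
from pytorch_lightning.callbacks import ProgressBar


class PrefixedProgressBar(ProgressBar):

    def print(self, *args, **kwargs):
        # Delegate to ProgressBar.print(), which writes through tqdm so the
        # bar is cleared, the tagged message printed, and the bar redrawn.
        super().print("[print]", *args, **kwargs)
```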

@@ -451,6 +459,22 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_predict_end(self, trainer, pl_module):
self.predict_progress_bar.close()

def print(
self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
):
active_progress_bar = None

if not self.main_progress_bar.disable:
active_progress_bar = self.main_progress_bar
elif not self.val_progress_bar.disable:
active_progress_bar = self.val_progress_bar
elif not self.test_progress_bar.disable:
active_progress_bar = self.test_progress_bar

if active_progress_bar is not None:
s = sep.join(map(str, args))
active_progress_bar.write(s, end=end, file=file, nolock=nolock)

def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
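`ProgressBar.print()` above resolves whichever bar is currently active (train, validation, or test) and routes the message through `tqdm.write()`, which clears the bar line, emits the message, and redraws the bar. A standalone sketch of that underlying tqdm behaviour, independent of Lightning:

```python
import time

from tqdm import tqdm

for i in tqdm(range(5), desc="demo"):
    # tqdm.write() prints above the bar without corrupting it; a bare
    # print() here would leave a half-drawn bar behind in the terminal.
    tqdm.write(f"step {i} done")
    time.sleep(0.1)
```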

10 changes: 7 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -189,8 +189,8 @@ def print(self, *args, **kwargs) -> None:
Prints only from process 0. Use this in any distributed mode to log only once.
Args:
-        *args: The thing to print. Will be passed to Python's built-in print function.
-        **kwargs: Will be passed to Python's built-in print function.
+        *args: The thing to print. The same as for Python's built-in print function.
+        **kwargs: The same as for Python's built-in print function.
Example::
@@ -199,7 +199,7 @@ def forward(self, x):
"""
if self.trainer.is_global_zero:
-            print(*args, **kwargs)
+            progress_bar = self.trainer.progress_bar_callback
+            if progress_bar is not None and progress_bar.is_enabled:
+                progress_bar.print(*args, **kwargs)
+            else:
+                print(*args, **kwargs)

def log(
self,
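Note the guard in `LightningModule.print()` above: output still happens only on global rank zero, and the progress-bar route is taken only when a bar callback exists and is enabled. When the bar is disabled, output degrades gracefully to the built-in `print`, which the second test below exercises via `bar.disable()`. Disabling can also be done at construction time; an illustrative sketch for this release:

```python
from pytorch_lightning import Trainer

# With the default bar turned off, self.print() falls back to built-in print().
trainer = Trainer(progress_bar_refresh_rate=0)
```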
69 changes: 68 additions & 1 deletion tests/callbacks/test_progress_bar.py
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import sys
from unittest import mock
-from unittest.mock import call, Mock
+from unittest.mock import ANY, call, Mock

import pytest
import torch
@@ -381,3 +382,69 @@ def training_step(self, batch, batch_idx):
def test_tqdm_format_num(input_num, expected):
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
assert tqdm.format_num(input_num) == expected


class PrintModel(BoringModel):

def training_step(self, *args, **kwargs):
self.print("training_step", end="")
return super().training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
self.print("validation_step", file=sys.stderr)
return super().validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
self.print("test_step")
return super().test_step(*args, **kwargs)


@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
def test_progress_bar_print(tqdm_write, tmpdir):
""" Test that printing in the LightningModule redirects arguments to the progress bar. """
model = PrintModel()
bar = ProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
max_steps=1,
callbacks=[bar],
)
trainer.fit(model)
trainer.test(model)
assert tqdm_write.call_count == 3
assert tqdm_write.call_args_list == [
call("training_step", end="", file=None, nolock=False),
call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
call("test_step", end=os.linesep, file=None, nolock=False),
]


@mock.patch('builtins.print')
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
""" Test that printing in LightningModule goes through built-in print functin when progress bar is disabled. """
model = PrintModel()
bar = ProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
max_steps=1,
callbacks=[bar],
)
bar.disable()
trainer.fit(model)
trainer.test(model)

mock_print.assert_has_calls([
call("training_step", end=""),
call("validation_step", file=ANY),
call("test_step"),
])
tqdm_write.assert_not_called()
