Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added state attributes for tqdm logger #2162

Merged
merged 5 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions ignite/contrib/handlers/tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ class ProgressBar(BaseLogger):
# Progress bar will looks like
# Epoch [2/50]: [64/128] 50%|█████ , loss=0.123 [06:17<12:34]


Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
are also logged along with the NLL and Accuracy after each iteration:

.. code-block:: python

pbar.attach(
trainer,
metric_names=["nll", "accuracy"],
state_attributes=["alpha", "beta"],
)


Note:
When adding attaching the progress bar to an engine, it is recommend that you replace
every print operation in the engine's handlers triggered every iteration with
Expand All @@ -88,6 +101,9 @@ class ProgressBar(BaseLogger):
Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, bar format may be needed to be set
to an empty string value.

.. versionchanged:: 0.5.0
`attach` now accepts an optional list of `state_attributes`

"""

_events_order = [
Expand Down Expand Up @@ -161,6 +177,7 @@ def attach( # type: ignore[override]
output_transform: Optional[Callable] = None,
event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED,
closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
state_attributes: Optional[List[str]] = None,
) -> None:
"""
Attaches the progress bar to an engine object.
Expand All @@ -176,6 +193,7 @@ def attach( # type: ignore[override]
:class:`~ignite.engine.events.Events`.
closing_event_name: event's name on which the progress bar is closed. Valid events are from
:class:`~ignite.engine.events.Events`.
state_attributes: list of attributes of the ``trainer.state`` to plot.

Note:
Accepted output value types are numbers, 0d and 1d torch tensors and strings.
Expand All @@ -193,7 +211,13 @@ def attach( # type: ignore[override]
if not self._compare_lt(event_name, closing_event_name):
raise ValueError(f"Logging event {event_name} should be called before closing event {closing_event_name}")

log_handler = _OutputHandler(desc, metric_names, output_transform, closing_event_name=closing_event_name)
log_handler = _OutputHandler(
desc,
metric_names,
output_transform,
closing_event_name=closing_event_name,
state_attributes=state_attributes,
)

super(ProgressBar, self).attach(engine, log_handler, event_name)
engine.add_event_handler(closing_event_name, self._close)
Expand All @@ -215,6 +239,7 @@ def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
class _OutputHandler(BaseOutputHandler):
"""Helper handler to log engine's output and/or metrics

pbar = ProgressBar()
Args:
description: progress bar description.
metric_names: list of metric names to plot or a string "all" to plot all available
Expand All @@ -226,6 +251,7 @@ class _OutputHandler(BaseOutputHandler):
closing_event_name: event's name on which the progress bar is closed. Valid events are from
:class:`~ignite.engine.events.Events` or any `event_name` added by
:meth:`~ignite.engine.engine.Engine.register_events`.
state_attributes: list of attributes of the ``trainer.state`` to plot.

"""

Expand All @@ -235,11 +261,14 @@ def __init__(
metric_names: Optional[Union[str, List[str]]] = None,
output_transform: Optional[Callable] = None,
closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
state_attributes: Optional[List[str]] = None,
):
if metric_names is None and output_transform is None:
# This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
metric_names = []
super(_OutputHandler, self).__init__(description, metric_names, output_transform, global_step_transform=None)
super(_OutputHandler, self).__init__(
description, metric_names, output_transform, global_step_transform=None, state_attributes=state_attributes
)
self.closing_event_name = closing_event_name

@staticmethod
Expand Down
38 changes: 38 additions & 0 deletions tests/ignite/contrib/handlers/test_tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,44 @@ def step(engine, batch):
assert actual == expected


def test_pbar_with_state_attrs(capsys):

n_iters = 2
data = list(range(n_iters))
loss_values = iter(range(n_iters))

def step(engine, batch):
loss_value = next(loss_values)
return loss_value

trainer = Engine(step)
trainer.state.alpha = 3.899
trainer.state.beta = torch.tensor(12.21)
trainer.state.gamma = torch.tensor([21.0, 6.0])

RunningAverage(alpha=0.5, output_transform=lambda x: x).attach(trainer, "batchloss")

pbar = ProgressBar()
pbar.attach(trainer, metric_names=["batchloss",], state_attributes=["alpha", "beta", "gamma"])

trainer.run(data=data, max_epochs=1)

captured = capsys.readouterr()
err = captured.err.split("\r")
err = list(map(lambda x: x.strip(), err))
err = list(filter(None, err))
actual = err[-1]
if get_tqdm_version() < LooseVersion("4.49.0"):
expected = (
"Iteration: [1/2] 50%|█████ , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<00:00]"
)
else:
expected = (
"Iteration: [1/2] 50%|█████ , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<?]"
)
assert actual == expected


def test_pbar_no_metric_names(capsys):

n_epochs = 2
Expand Down