Skip to content

Commit

Permalink
Enableself.log in most functions. (#4969)
Browse files Browse the repository at this point in the history
* refactor

* solve pyright

* remove logging in batch_start functions

* update docs

* update doc

* resolve bug

* update

* correct script

* resolve on comments
  • Loading branch information
tchaton authored Dec 6, 2020
1 parent 9b1afa8 commit 2e838e6
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 63 deletions.
17 changes: 11 additions & 6 deletions docs/source/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

.. role:: hidden
:class: hidden-section

.. _logging:


Expand Down Expand Up @@ -57,9 +57,11 @@ Logging from a LightningModule

Lightning offers automatic log functionalities for logging scalars, or manual logging for anything else.

Automatic logging
Automatic Logging
=================
Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :ref:`lightning_module`.
Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log`
method to log from anywhere in a :ref:`lightning_module` and :ref:`callbacks`
except functions with `batch_start` in their names.

.. code-block:: python
Expand Down Expand Up @@ -95,6 +97,9 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a
argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice.


If your work requires to log in an unsupported function, please open an issue with a clear description of why it is blocking you.


Manual logging
==============
If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly.
Expand Down Expand Up @@ -144,8 +149,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`
def experiment(self):
# Return the experiment object associated with this logger.
pass
@property

@property
def version(self):
# Return the experiment version, int or str.
return '0.1'
Expand Down Expand Up @@ -238,7 +243,7 @@ if you are using a logger. These defaults can be customized by overriding the
:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.

.. code-block:: python
def get_progress_bar_dict(self):
# don't show the version number
items = super().get_progress_bar_dict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ def cache_result(self) -> None:
self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)

# update logged_metrics, progress_bar_metrics, callback_metrics
self.update_logger_connector()

if "epoch_end" in fx_name:
self.update_logger_connector()

self.reset_model()

Expand All @@ -355,18 +357,19 @@ def update_logger_connector(self) -> None:
logger_connector = self.trainer.logger_connector

callback_metrics = {}
is_train = self._stage == LoggerStages.TRAIN
batch_pbar_metrics = {}
batch_log_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value

if not self._has_batch_loop_finished:
# get pbar
batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
logger_connector.add_progress_bar_metrics(batch_pbar_metrics)
batch_log_metrics = self.get_latest_batch_log_metrics()

if is_train:
# Only log and add to callback epoch step during evaluation, test.
batch_log_metrics = self.get_latest_batch_log_metrics()
logger_connector.logged_metrics.update(batch_log_metrics)

callback_metrics.update(batch_pbar_metrics)
callback_metrics.update(batch_log_metrics)
else:
Expand All @@ -393,6 +396,9 @@ def update_logger_connector(self) -> None:
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)

batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics

def run_batch_from_func_name(self, func_name) -> Dict:
results = [getattr(hook_result, func_name) for hook_result in self._internals.values()]
results = [func(include_forked_originals=False) for func in results]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,13 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):
return gathered_epoch_outputs

def log_train_step_metrics(self, batch_output):
_, batch_log_metrics = self.cached_results.update_logger_connector()
# when metrics should be logged
if self.should_update_logs or self.trainer.fast_dev_run:
# logs user requested information to logger
metrics = self.cached_results.get_latest_batch_log_metrics()
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(metrics)
if grad_norm_dic is None:
grad_norm_dic = {}
if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(batch_log_metrics)
14 changes: 6 additions & 8 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def on_evaluation_model_train(self, *args, **kwargs):

def on_evaluation_end(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_test_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_validation_end', *args, **kwargs)

def reload_evaluation_dataloaders(self):
model = self.trainer.get_model()
Expand Down Expand Up @@ -329,9 +329,9 @@ def store_predictions(self, output, batch_idx, dataloader_idx):
def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
if self.testing:
self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

def log_evaluation_step_metrics(self, output, batch_idx):
if self.trainer.running_sanity_check:
Expand All @@ -346,10 +346,8 @@ def log_evaluation_step_metrics(self, output, batch_idx):
self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)

def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
cached_batch_log_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
cached_batch_pbar_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics()
cached_results = self.trainer.logger_connector.cached_results
cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()

step_log_metrics.update(cached_batch_log_metrics)
step_pbar_metrics.update(cached_batch_pbar_metrics)
Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,8 @@ def call_setup_hook(self, model):
model.setup(stage_name)

def _reset_result_and_set_hook_fx_name(self, hook_name):
if "batch_start" in hook_name:
return True
model_ref = self.get_model()
if model_ref is not None:
# used to track current hook name called
Expand All @@ -868,10 +870,9 @@ def _cache_logged_metrics(self):
# capture logging for this hook
self.logger_connector.cache_logged_metrics()

def call_hook(self, hook_name, *args, capture=False, **kwargs):
def call_hook(self, hook_name, *args, **kwargs):
# set hook_name to model + reset Result obj
if capture:
self._reset_result_and_set_hook_fx_name(hook_name)
skip = self._reset_result_and_set_hook_fx_name(hook_name)

# always profile hooks
with self.profiler.profile(hook_name):
Expand All @@ -894,7 +895,7 @@ def call_hook(self, hook_name, *args, capture=False, **kwargs):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
output = accelerator_hook(*args, **kwargs)

if capture:
if not skip:
self._cache_logged_metrics()
return output

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,8 @@ def run_on_epoch_end_hook(self, epoch_output):
# inform logger the batch loop has finished
self.trainer.logger_connector.on_train_epoch_end()

self.trainer.call_hook('on_epoch_end', capture=True)
self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
self.trainer.call_hook('on_epoch_end')
self.trainer.call_hook('on_train_epoch_end', epoch_output)

def increment_accumulated_grad_global_step(self):
num_accumulated_batches_reached = self._accumulated_batches_reached()
Expand Down
13 changes: 13 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
import sys
from argparse import ArgumentParser
Expand Down
27 changes: 8 additions & 19 deletions tests/trainer/logging_tests/test_eval_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,6 @@ def make_logging(self, pl_module, func_name,
"forked": False,
"func_name": func_name}

"""
def on_validation_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
Expand All @@ -486,13 +484,15 @@ def on_validation_epoch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)

"""
def on_batch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
"""

def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
Expand All @@ -510,8 +510,6 @@ def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)

"""

def on_validation_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
Expand Down Expand Up @@ -541,16 +539,14 @@ def validation_step(self, batch, batch_idx):
trainer.fit(model)
trainer.test()

"""
assert test_callback.funcs_called_count["on_epoch_start"] == 1
assert test_callback.funcs_called_count["on_batch_start"] == 1
# assert test_callback.funcs_called_count["on_batch_start"] == 1
assert test_callback.funcs_called_count["on_batch_end"] == 1
assert test_callback.funcs_called_count["on_validation_start"] == 1
assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1
assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
# assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
assert test_callback.funcs_called_count["on_validation_batch_end"] == 4
assert test_callback.funcs_called_count["on_epoch_end"] == 1
"""
assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1

# Make sure the func_name exists within callback_metrics. If not, we missed some
Expand Down Expand Up @@ -662,7 +658,6 @@ def make_logging(self, pl_module, func_name,
"forked": False,
"func_name": func_name}

"""
def on_test_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
Expand All @@ -675,11 +670,8 @@ def on_test_epoch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)

def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_step_end', 5, on_steps=self.choices,
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)

# used to make sure aggregation works fine.
Expand All @@ -690,7 +682,6 @@ def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, datalo
def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
"""

def on_test_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False],
Expand Down Expand Up @@ -728,13 +719,11 @@ def test_dataloader(self):
)
trainer.fit(model)
trainer.test()
"""

assert test_callback.funcs_called_count["on_test_start"] == 1
assert test_callback.funcs_called_count["on_epoch_start"] == 2
assert test_callback.funcs_called_count["on_test_epoch_start"] == 1
assert test_callback.funcs_called_count["on_test_batch_start"] == 4
assert test_callback.funcs_called_count["on_test_step_end"] == 4
"""
assert test_callback.funcs_called_count["on_test_batch_end"] == 4
assert test_callback.funcs_called_count["on_test_epoch_end"] == 1

# Make sure the func_name exists within callback_metrics. If not, we missed some
Expand Down
17 changes: 1 addition & 16 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx,
"prog_bar": prog_bar,
"forked": False,
"func_name": func_name}
"""

def on_train_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
Expand All @@ -571,15 +571,6 @@ def on_train_epoch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)

def on_batch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
Expand All @@ -592,7 +583,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
# with func = np.mean if on_epoch else func = np.max
self.count += 1

"""
def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
Expand Down Expand Up @@ -629,17 +619,12 @@ def training_step(self, batch, batch_idx):
)
trainer.fit(model)

"""
assert test_callback.funcs_called_count["on_train_start"] == 1
assert test_callback.funcs_called_count["on_epoch_start"] == 2
assert test_callback.funcs_called_count["on_train_epoch_start"] == 2
assert test_callback.funcs_called_count["on_batch_start"] == 4
assert test_callback.funcs_called_count["on_train_batch_start"] == 4
assert test_callback.funcs_called_count["on_batch_end"] == 4
assert test_callback.funcs_called_count["on_epoch_end"] == 2
assert test_callback.funcs_called_count["on_train_batch_end"] == 4
"""
assert test_callback.funcs_called_count["on_epoch_end"] == 2
assert test_callback.funcs_called_count["on_train_epoch_end"] == 2

Expand Down

0 comments on commit 2e838e6

Please sign in to comment.