Skip to content

Commit

Permalink
[feat] Logging refactor 2/n - train (#4495)
Browse files Browse the repository at this point in the history
* update logging

* solve more bugs

* replace Mapping by Dict

* update on comments

* resolve pep8

* Apply suggestions from code review

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update on comments

* typo

* update for coverage

* update test

* update

* Update tests/models/test_hooks.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* Update tests/models/test_hooks.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* update on comments

* remove deepcopy

* remove useless look for

* another small optim

* extra optim

* remove lastest optim, can be source of bug

* resolve bug

* add docstring

* optimize coverage

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/logging_tests/test_distributed_logging.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/evaluation_loop.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/logging/test_logger_connector.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update on comments

* update

* update on comments

* update parity speed

* get it down to 0.65

* update

* 0.8 max_dif

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
5 people authored Nov 5, 2020
1 parent 62ea461 commit 9c8701f
Show file tree
Hide file tree
Showing 15 changed files with 733 additions and 257 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.parametrize('cls_model,max_diff', [
(ParityModuleRNN, 0.05),
(ParityModuleMNIST, 0.70)
(ParityModuleMNIST, 0.8)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _on_validation_start_log():
@staticmethod
def _on_validation_end_log():
"""Called when the validation loop ends."""
return {"on_step": [False], "on_epoch": [False, True]}
return None

@staticmethod
def _on_test_start_log():
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,14 @@ def __init__(self, trainer):
self._callback_hook_validator = CallbackHookNameValidator()
self._current_stage = None

def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]:
""" Function to access cached_results using str or bool. Bool is used only for testing"""
stage_or_testing = str(stage_or_testing)
stages = self._stages
if stage_or_testing in self._stages:
return self._cached_results[stage_or_testing]
if stage_or_testing in LOOKUP_TABLE:
# Acces using trainer.testing
stage = LOOKUP_TABLE[stage_or_testing]
return self._cached_results[stage]
raise MisconfigurationException(
f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}"
f" or {LOOKUP_TABLE.keys()}"
)
@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results[self._current_stage]

def set_stage(self, stage_or_testing: str, reset:bool = False) -> None:
self._current_stage = self._determine_stage(stage_or_testing)
if reset:
self.cached_results(stage_or_testing).reset()
self.cached_results.reset()

def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name,
Expand All @@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.cached_results(testing)._batch_size = Result.extract_batch_size(batch)
self.cached_results._batch_size = Result.extract_batch_size(batch)

def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
self._cached_results["train"]._split_idx = split_idx
self._cached_results["train"]._opt_idx = opt_idx
self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch)
def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
self.cached_results._split_idx = split_idx
self.cached_results._opt_idx = opt_idx
self.cached_results._batch_size = Result.extract_batch_size(split_batch)

def on_train_batch_end(self) -> None:
self._cached_results["train"]._split_idx = None
self._cached_results["train"]._opt_idx = None
self._cached_results["train"]._batch_size = None
self.cached_results._split_idx = None
self.cached_results._opt_idx = None
self.cached_results._batch_size = None

def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str:
stage_or_testing = str(stage_or_testing)
Expand All @@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps

@property
def should_flush_logs(self):
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
return should_flush or self.trainer.should_stop

@property
def should_update_logs(self):
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
return should_log_every_n_steps or self.trainer.should_stop

def configure_logger(self, logger):
if logger is True:
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
Expand All @@ -130,6 +129,53 @@ def configure_logger(self, logger):
else:
self.trainer.logger = logger

def cache_training_step_metrics(self, opt_closure_result):
"""
This function is responsible to update
logger_connector internals metrics holder based for depreceated logging
"""
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)

# temporary dict to collect metrics
logged_metrics_tmp = {}
pbar_metrics_tmp = {}
callback_metrics_tmp = {}

if using_results_obj:
batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
include_forked_originals=False
)
logged_metrics_tmp.update(batch_log_metrics)

batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
include_forked_originals=False
)
pbar_metrics_tmp.update(batch_pbar_metrics)

forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
callback_metrics_tmp.update(forked_metrics)
callback_metrics_tmp.update(logged_metrics_tmp)

else:
batch_log_metrics = opt_closure_result.training_step_output.log_metrics
logged_metrics_tmp.update(batch_log_metrics)

callback_metrics = opt_closure_result.training_step_output.callback_metrics
callback_metrics_tmp.update(callback_metrics)

batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
pbar_metrics_tmp.update(batch_pbar_metrics)

# track progress bar metrics
if len(pbar_metrics_tmp) > 0:
self.add_progress_bar_metrics(pbar_metrics_tmp)

self.callback_metrics.update(callback_metrics_tmp)

# save legacy log metrics
self.logged_metrics.update(logged_metrics_tmp)
self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

def log_metrics(self, metrics, grad_norm_dic, step=None):
"""Logs the metric dict passed in.
If `step` parameter is None and `step` key is presented is metrics,
Expand Down Expand Up @@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
if num_loaders == 1:
self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics)

def on_train_epoch_end(self, epoch_output):
pass
def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True

def log_train_epoch_end_metrics(self,
epoch_output,
Expand Down Expand Up @@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self,
# ------------------
if is_1_0_result:
# lightning module hook
epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers)
self.training_epoch_end(model, epoch_output, num_optimizers)

# log/aggregate metrics automatically
epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics())
epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics())

# TODO: deprecate 1.0
else:
Expand All @@ -459,6 +504,14 @@ def log_train_epoch_end_metrics(self,
)
epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

# it will perform reduction over epoch and return log metrics
cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()

# update
epoch_log_metrics.update(cached_epoch_log_metrics)
epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)

# --------------------------
# track results
# --------------------------
Expand All @@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self,
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
self.callback_metrics.update(epoch_progress_bar_metrics)

# reset epoch loop result for next epoch
self.cached_results.reset()

def training_epoch_end(self, model, epoch_output, num_optimizers):
if not is_overridden('training_epoch_end', model=model):
return Result()
return

# run training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = 'training_epoch_end'
model._results = Result()

epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
Expand All @@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
# lightningmodule hook
epoch_output = model.training_epoch_end(epoch_output)

model._current_fx_name = ''

if epoch_output is not None:
raise MisconfigurationException('training_epoch_end expects a return of None. '
'HINT: remove the return statement in training_epoch_end')

# user can ALSO log at the end of an epoch
new_epoch_end_logs = model._results
return new_epoch_end_logs
# capture logging
self.trainer.logger_connector.cache_logged_metrics()

def __run_legacy_training_epoch_end(
self,
Expand All @@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end(

# run training_epoch_end
# a list with a result per optimizer index
model._current_fx_name = 'training_epoch_end'
epoch_output = model.training_epoch_end(epoch_output)

# capture logging
self.trainer.logger_connector.cache_logged_metrics()

if isinstance(epoch_output, Result):
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
Expand Down Expand Up @@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)

# with manual opt need 1+ metrics because meta is always there
# with manual opt need 1 + metrics because meta is always there
if opt_outputs.minimize is not None:
opt_outputs.minimize = opt_outputs.minimize.mean()
epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
Expand Down Expand Up @@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

def log_train_step_metrics(self, batch_output):
# when metrics should be logged
should_log_metrics = (
(self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop
)
if should_log_metrics or self.trainer.fast_dev_run:
if self.should_update_logs or self.trainer.fast_dev_run:
# logs user requested information to logger
metrics = batch_output.batch_log_metrics
metrics = self.cached_results.get_latest_batch_log_metrics()
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic)
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ def __log_result_step_metrics(self, output, batch_idx):
step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)

cached_batch_log_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()

if len(step_log_metrics) > 0:
# make the metrics appear as a different line in the same graph
metrics_by_epoch = {}
Expand Down
24 changes: 23 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,25 @@ def call_setup_hook(self, model):
self.setup(stage_name)
model.setup(stage_name)

def _reset_result_and_set_hook_fx_name(self, hook_name):
model_ref = self.get_model()
if model_ref is not None:
# used to track current hook name called
model_ref._results = Result()
model_ref._current_hook_fx_name = hook_name

def _cache_logged_metrics(self):
model_ref = self.get_model()
if model_ref is not None:
# capture logging for this hook
self.logger_connector.cache_logged_metrics()

def call_hook(self, hook_name, *args, **kwargs):
# temporary. Don't modify evaluation behaviour
if self.logger_connector._current_stage == "train":
# set hook_name to model + reset Result obj
self._reset_result_and_set_hook_fx_name(hook_name)

# always profile hooks
with self.profiler.profile(hook_name):

Expand All @@ -860,4 +878,8 @@ def call_hook(self, hook_name, *args, **kwargs):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
output = accelerator_hook(*args, **kwargs)

return output
# temporary. Don't modify evaluation behaviour
if self.logger_connector._current_stage == "train":
# capture logging
self._cache_logged_metrics()
return output
Loading

0 comments on commit 9c8701f

Please sign in to comment.