
Logging #4880

Merged: 13 commits, Nov 27, 2020
4 changes: 2 additions & 2 deletions benchmarks/test_parity.py
@@ -9,10 +9,10 @@
from tests.base.models import ParityModuleMNIST, ParityModuleRNN


# TODO: explore where the time leak comes from
# ParityModuleMNIST runs with num_workers=1
@pytest.mark.parametrize('cls_model,max_diff', [
    (ParityModuleRNN, 0.05),
-    (ParityModuleMNIST, 0.99)
+    (ParityModuleMNIST, 0.22)
Member: WOW :]

Contributor (author): Oh no, I increased num_workers to 1!

])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff):
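For context on the threshold: max_diff bounds how much slower the Lightning run may be than an equivalent vanilla-PyTorch run of the same model. A minimal sketch of that kind of check, assuming max_diff is a relative slowdown ratio (the helper below is hypothetical, not the benchmark's actual code):

# Hypothetical parity check: Lightning may be at most `max_diff` slower,
# as a relative ratio, than vanilla PyTorch on the same model.
def assert_time_parity(vanilla_s: float, lightning_s: float, max_diff: float) -> None:
    slowdown = lightning_s / vanilla_s - 1.0
    assert slowdown < max_diff, f"{slowdown:.2%} slower, limit {max_diff:.0%}"

# With num_workers=1 the MNIST overhead shrinks enough to tighten the bound:
assert_time_parity(vanilla_s=10.0, lightning_s=11.5, max_diff=0.22)  # 15% < 22%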
50 changes: 41 additions & 9 deletions pytorch_lightning/profiler/profilers.py
@@ -26,6 +26,7 @@

import fsspec
import numpy as np
+
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.cloud_io import get_filesystem

@@ -121,14 +122,15 @@ class SimpleProfiler(BaseProfiler):
    the mean duration of each action and the total time spent over the entire training run.
    """

-    def __init__(self, output_filename: Optional[str] = None):
+    def __init__(self, output_filename: Optional[str] = None, extended=True):
        """
        Args:
            output_filename: optionally save profile results to file instead of printing
                to std out when training is finished.
        """
        self.current_actions = {}
        self.recorded_durations = defaultdict(list)
+        self.extended = extended

        self.output_fname = output_filename
        self.output_file = None
@@ -137,6 +139,7 @@ def __init__(self, output_filename: Optional[str] = None):
            self.output_file = fs.open(self.output_fname, "w")

        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        self.start_time = time.monotonic()
        super().__init__(output_streams=streaming_out)

    def start(self, action_name: str) -> None:
@@ -156,18 +159,47 @@ def stop(self, action_name: str) -> None:
        duration = end_time - start_time
        self.recorded_durations[action_name].append(duration)

+    def make_report(self):
+        total_duration = time.monotonic() - self.start_time
+        report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()]
+        report.sort(key=lambda x: x[2], reverse=True)
+        return report, total_duration

    def summary(self) -> str:
        output_string = "\n\nProfiler Report\n"

-        def log_row(action, mean, total):
-            return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
+        if self.extended:
+
+            if len(self.recorded_durations) > 0:
+                max_key = np.max([len(k) for k in self.recorded_durations.keys()])
+
+                def log_row(action, mean, num_calls, total, per):
+                    row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|"
+                    row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
+                    return row
+
+                output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
+                output_string_len = len(output_string)
+                output_string += f"{os.linesep}{'-' * output_string_len}"
+                report, total_duration = self.make_report()
+                output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
+                output_string += f"{os.linesep}{'-' * output_string_len}"
+                for action, durations, duration_per in report:
+                    output_string += log_row(
+                        action, f"{np.mean(durations):.5}", f"{len(durations):}",
+                        f"{np.sum(durations):.5}", f"{duration_per:.5}"
+                    )
+        else:
+            def log_row(action, mean, total):
+                return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
+
+            output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
+            output_string += f"{os.linesep}{'-' * 65}"
+            for action, durations in self.recorded_durations.items():
+                output_string += log_row(
+                    action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
+                )
-        output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
-        output_string += f"{os.linesep}{'-' * 65}"
-
-        for action, durations in self.recorded_durations.items():
-            output_string += log_row(
-                action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
-            )
        output_string += os.linesep
        return output_string

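To make the new extended report concrete: make_report sums each action's recorded durations and expresses the sum as a percentage of the wall-clock time since the profiler was created, sorting rows so the most expensive action comes first. A minimal standalone sketch of the same math (an illustration, not the PR's code):

import time
from collections import defaultdict

recorded_durations = defaultdict(list)
start_time = time.monotonic() - 1.0  # pretend 1 s of wall time has elapsed

# suppose two actions were timed during that second
recorded_durations["run_training_batch"] += [0.12, 0.11, 0.13]
recorded_durations["model_backward"] += [0.04, 0.05, 0.04]

total_duration = time.monotonic() - start_time
report = [
    (action, sum(durs), len(durs), 100.0 * sum(durs) / total_duration)
    for action, durs in recorded_durations.items()
]
report.sort(key=lambda row: row[3], reverse=True)  # most expensive first

print(f"{'Action':<20s} {'Total (s)':>10s} {'Calls':>6s} {'%':>7s}")
for action, total, calls, pct in report:
    print(f"{action:<20s} {total:>10.3f} {calls:>6d} {pct:>6.1f} %")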
@@ -372,17 +372,22 @@ def cache_result(self) -> None:
        """
        This function is called after every hook
        and stores the result object
        """
-        model_ref = self.trainer.get_model()
+        with self.trainer.profiler.profile("cache_result"):
+            model_ref = self.trainer.get_model()

-        # extract hook results
-        hook_result = model_ref._results
+            # extract hook results
+            hook_result = model_ref._results
+
+            if len(hook_result) == 1:
+                model_ref._current_hook_fx_name = None
+                model_ref._current_fx_name = ''
+                return

-        # extract model information
-        fx_name, dataloader_idx = self.current_model_info()
+            # extract model information
+            fx_name, dataloader_idx = self.current_model_info()

-        # add only if anything has been logged
-        # default len is 1 due to _internals
-        if len(hook_result) > 1:
+            # add only if anything has been logged
+            # default len is 1 due to _internals

            if fx_name not in self._internals:
                self._internals[fx_name] = HookResultStore(fx_name)
@@ -406,8 +411,7 @@ def cache_result(self) -> None:
            # update logged_metrics, progress_bar_metrics, callback_metrics
            self.update_logger_connector(fx_name)

-        # reset _results, fx_name
-        self.reset_model()
+            self.reset_model()

    def update_logger_connector(self, fx_name: str = None) -> None:
        """
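A quick illustration of the len(hook_result) == 1 guard above: a Result behaves like a dict that always carries one internal bookkeeping entry, so a length of 1 means the hook logged nothing and caching can return early. A hypothetical dict-based stand-in (not Lightning's actual Result class):

# fresh result: only the internal entry is present, nothing was logged
hook_result = {"meta": {}}
assert len(hook_result) == 1  # cache_result can return early

# after a hook calls self.log(...), extra entries appear
hook_result["train_loss"] = 0.42
assert len(hook_result) > 1   # now worth storing in _internals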
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -23,7 +23,7 @@

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
+from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
+
+from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.supporters import PredictionCollection
-from pytorch_lightning.core.step_result import Result, EvalResult
+from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.warning_utils import WarningCache


@@ -105,9 +106,9 @@ def on_evaluation_model_train(self, *args, **kwargs):

    def on_evaluation_end(self, *args, **kwargs):
        if self.testing:
-            self.trainer.call_hook('on_test_end', *args, **kwargs)
+            self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs)
        else:
-            self.trainer.call_hook('on_validation_end', *args, **kwargs)
+            self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs)

    def reload_evaluation_dataloaders(self):
        model = self.trainer.get_model()
@@ -167,6 +168,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        model_ref = self.trainer.get_model()
+        model_ref._results = Result()
        # run actual test step
        if self.testing:
            model_ref._current_fx_name = "test_step"
@@ -327,9 +329,9 @@ def store_predictions(self, output, batch_idx, dataloader_idx):
    def on_evaluation_epoch_end(self, *args, **kwargs):
        # call the callback hook
        if self.testing:
-            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
+            self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs)
        else:
-            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+            self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs)

    def log_evaluation_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
23 changes: 14 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -511,8 +511,9 @@ def train(self):
                # hook
                self.train_loop.on_train_epoch_start(epoch)

-                # run train epoch
-                self.train_loop.run_training_epoch()
+                with self.profiler.profile("run_training_epoch"):
+                    # run train epoch
+                    self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:

@@ -604,8 +605,9 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                # lightning module methods
-                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
-                output = self.evaluation_loop.evaluation_step_end(output)
+                with self.profiler.profile("evaluation_step_and_end"):
+                    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
+                    output = self.evaluation_loop.evaluation_step_end(output)

                # hook + store predictions
                self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
@@ -656,7 +658,8 @@ def track_output_for_epoch_end(self, outputs, output):
    def run_test(self):
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
-        eval_loop_results, _ = self.run_evaluation(test_mode=True)
+        with self.profiler.profile("run_test_evaluation"):
+            eval_loop_results, _ = self.run_evaluation(test_mode=True)

        if len(eval_loop_results) == 0:
            return 1
@@ -863,16 +866,18 @@ def _reset_result_and_set_hook_fx_name(self, hook_name):
        # used to track current hook name called
        model_ref._results = Result()
        model_ref._current_hook_fx_name = hook_name
+        return False

    def _cache_logged_metrics(self):
        model_ref = self.get_model()
        if model_ref is not None:
            # capture logging for this hook
            self.logger_connector.cache_logged_metrics()

-    def call_hook(self, hook_name, *args, **kwargs):
+    def call_hook(self, hook_name, *args, capture=False, **kwargs):
        # set hook_name to model + reset Result obj
-        self._reset_result_and_set_hook_fx_name(hook_name)
+        if capture:
+            self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):
@@ -895,6 +900,6 @@ def call_hook(self, hook_name, *args, **kwargs):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

-        # capture logging
-        self._cache_logged_metrics()
+        if capture:
+            self._cache_logged_metrics()
        return output
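In short, capture makes the logging bookkeeping opt-in: every hook is still profiled, but only hooks that may log metrics pay for resetting the Result object beforehand and caching logged metrics afterwards. A stripped-down sketch of that control flow (a hypothetical MiniTrainer, not Lightning's actual class):

from contextlib import contextmanager

class MiniTrainer:
    # hypothetical stand-in showing only the capture control flow

    @contextmanager
    def profile(self, action: str):
        yield  # timing elided for brevity

    def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> None:
        print(f"reset Result for {hook_name}")

    def _cache_logged_metrics(self) -> None:
        print("cache logged metrics")

    def on_epoch_end(self) -> None:
        print("hook body runs")

    def call_hook(self, hook_name: str, *args, capture: bool = False, **kwargs):
        if capture:  # only logging-capable hooks need the Result reset
            self._reset_result_and_set_hook_fx_name(hook_name)
        with self.profile(hook_name):  # hooks are always profiled
            output = getattr(self, hook_name)(*args, **kwargs)
        if capture:  # pull whatever the hook logged into the trainer's cache
            self._cache_logged_metrics()
        return output

MiniTrainer().call_hook("on_epoch_end", capture=True)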
49 changes: 26 additions & 23 deletions pytorch_lightning/trainer/training_loop.py
@@ -24,8 +24,8 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
-from pytorch_lightning.utilities import parsing, AMPType
+from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
+from pytorch_lightning.utilities import AMPType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -320,7 +320,9 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)

        # manually capture logged metrics
-        model_ref._results = Result()
+        model_ref._current_fx_name = 'training_step'
+        model_ref._results = Result()
        training_step_output = self.trainer.accelerator_backend.training_step(args)
        self.trainer.logger_connector.cache_logged_metrics()
@@ -536,7 +538,8 @@ def run_training_epoch(self):
            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
-            batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
+            with self.trainer.profiler.profile("run_training_batch"):
+                batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

            # when returning -1 from train_step, we end epoch early
            if batch_output.signal == -1:
@@ -766,27 +769,28 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
"""
wrap the forward step in a closure so second order methods work
"""
# lightning module hook
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result
with self.trainer.profiler.profile("training_step_and_backward"):
# lightning module hook
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
self._curr_step_result = result

if result is None:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
return None
if result is None:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
return None

if self.trainer.train_loop.automatic_optimization:
# backward pass
with self.trainer.profiler.profile("model_backward"):
self.backward(result, optimizer, opt_idx)
if self.trainer.train_loop.automatic_optimization:
# backward pass
with self.trainer.profiler.profile("model_backward"):
self.backward(result, optimizer, opt_idx)

# hook - call this hook only
# when gradients have finished to accumulate
if not self.should_accumulate():
self.on_after_backward(result.training_step_output, batch_idx, result.loss)
# hook - call this hook only
# when gradients have finished to accumulate
if not self.should_accumulate():
self.on_after_backward(result.training_step_output, batch_idx, result.loss)

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(result.loss)
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(result.loss)

return result

@@ -814,9 +818,8 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)

    def run_on_epoch_end_hook(self, epoch_output):
-        self.trainer.call_hook('on_epoch_end')
-        self.trainer.call_hook('on_train_epoch_end', epoch_output)
-
+        self.trainer.call_hook('on_epoch_end', capture=True)
+        self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
        self.trainer.logger_connector.on_train_epoch_end()

    def increment_accumulated_grad_global_step(self):
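Worth noting: the regions added in this PR nest (run_training_epoch wraps run_training_batch, which wraps training_step_and_backward, which wraps model_backward), so time spent in an inner region is also attributed to every enclosing one, and the per-action percentages in the extended report can legitimately sum past 100 %. A small self-contained sketch of that nesting, assuming a SimpleProfiler-like context manager:

import time
from collections import defaultdict
from contextlib import contextmanager

durations = defaultdict(list)

@contextmanager
def profile(action: str):
    start = time.monotonic()
    try:
        yield
    finally:
        durations[action].append(time.monotonic() - start)

# action names mirror the ones added in this PR
with profile("run_training_epoch"):
    for _ in range(2):  # two batches
        with profile("run_training_batch"):
            with profile("training_step_and_backward"):
                with profile("model_backward"):
                    time.sleep(0.01)  # stand-in for the backward pass

# inner time is counted by every enclosing region
print({action: round(sum(d), 3) for action, d in durations.items()})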
5 changes: 2 additions & 3 deletions tests/base/models.py
@@ -20,7 +20,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.core.lightning import LightningModule
-from tests.base.datasets import TrialMNIST, AverageDataset, MNIST
+from tests.base.datasets import MNIST, AverageDataset, TrialMNIST


class Generator(nn.Module):
@@ -217,5 +217,4 @@ def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
-        return DataLoader(MNIST(train=True, download=True,),
-                          batch_size=128)
+        return DataLoader(MNIST(train=True, download=True,), batch_size=128, num_workers=1)