Skip to content

Commit

Permalink
Prevent loss to be moved to the cpu before backward call. (#9308)
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Sep 7, 2021
1 parent 65822c3 commit 3a7dd42
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))


- Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
Expand Down
194 changes: 194 additions & 0 deletions pytorch_lightning/loops/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT


def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
"""Checks for finite loss value.
Args:
loss: the loss value to check to be finite
"""
if loss is not None and not torch.isfinite(loss).all():
raise ValueError(f"The loss returned in `training_step` is {loss}.")


def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
"""Sanity checks that training produced a valid output and optimizer step has already been called in manual
optimization.
Args:
model: a reference to the trainer
training_step_output: the output of the training step (before wrapping in an AttributeDict)
"""
if isinstance(training_step_output, torch.Tensor) and not model.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
elif model.automatic_optimization:
if not any(
(
isinstance(training_step_output, torch.Tensor),
(isinstance(training_step_output, Mapping) and "loss" in training_step_output),
training_step_output is None,
)
):
raise MisconfigurationException(
"In automatic optimization, `training_step` must either return a Tensor, "
"a dict with key 'loss' or None (where the step will be skipped)."
)


def _process_training_step_output(
trainer: "pl.Trainer", training_step_output: STEP_OUTPUT
) -> Tuple[Optional[ResultCollection], Optional[Any]]:
"""Adds the :param:`training_step_output` to the trainer's results
Args:
trainer: a reference to the trainer
training_step_output: the output of the training step (before wrapping into an AttributeDict)
Returns:
the updated results (None if the training_step's output was None) and hiddens exract from the results
"""
if training_step_output is None:
return None, None

results = trainer._results

loss = None
hiddens = None

# handle dict return
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
loss = training_step_output.get("loss")
hiddens = training_step_output.get("hiddens")
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = apply_to_collection(hiddens, torch.Tensor, lambda t: t.detach())
# use the setter instead of `dict.update` because it calls `detach` on the tensor items
results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

# handle scalar return
elif isinstance(training_step_output, torch.Tensor):
loss = training_step_output

if trainer.terminate_on_nan:
check_finite_loss(loss)

# the loss shouldn't be moved to cpu.
if trainer.move_metrics_to_cpu:
results.cpu()

# map to results under the hood
results.minimize = loss

return results, hiddens


def _build_training_step_kwargs(
lightning_module: "pl.LightningModule",
optimizers: Sequence[Optimizer],
batch: Any,
batch_idx: int,
opt_idx: Optional[int],
hiddens: Optional[Tensor],
) -> Dict[str, Any]:
"""Builds the keyword arguments for training_step
Args:
lightning_module: the LightningModule with a `training_step` hook implementation
optimizers: the list of optimizers from the Trainer
batch: the batch to train on
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
hiddens: the hidden state of the previous RNN iteration
Returns:
the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
step_kwargs = OrderedDict([("batch", batch)])

training_step_fx = getattr(lightning_module, "training_step")

if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
step_kwargs["batch_idx"] = batch_idx

if len(optimizers) > 1:
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
raise ValueError(
"Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
" in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
" argument or set `self.automatic_optimization = True`."
)
step_kwargs["optimizer_idx"] = opt_idx
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(optimizers)} optimizers but"
" `training_step` is missing the `optimizer_idx` argument."
)

# pass hiddens if using tbptt
if lightning_module.truncated_bptt_steps > 0:
step_kwargs["hiddens"] = hiddens

return step_kwargs


def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
"""Attach the dataloader"""
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
# restore iteration
dataloader_iter = enumerate(data_fetcher, batch_idx)
else:
dataloader_iter = iter(data_fetcher)
return dataloader_iter


@contextmanager
def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
"""
Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
This is useful for example when when accumulating gradients to reduce communication when it is not needed.
Args:
trainer: the trainer instance with a reference to a training type plugin
block: whether the context manager is enabled or not
Returns:
context manager with sync behaviour off
"""
if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
with trainer.training_type_plugin.block_backward_sync():
yield None
else:
yield None
17 changes: 17 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,20 @@ def test_log_gpu_memory_without_logging_on_step(tmpdir, log_gpu_memory):
assert "max_gpu_mem" in trainer.logged_metrics
else:
assert "gpu_id: 1/memory.used (MB)" in trainer.logged_metrics


@RunIf(min_gpus=1)
def test_move_metrics_to_cpu(tmpdir):
class TestModel(BoringModel):
def on_before_backward(self, loss: torch.Tensor) -> None:
assert loss.device.type == "cuda"

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
amp_backend="native",
precision=16,
move_metrics_to_cpu=True,
gpus=1,
)
trainer.fit(TestModel())

0 comments on commit 3a7dd42

Please sign in to comment.