ref: inner train loop (intermediate step) 3/n #3362

Merged: 6 commits, Sep 5, 2020
66 changes: 4 additions & 62 deletions pytorch_lightning/trainer/training_loop.py
@@ -934,73 +934,15 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """
         wrap the forward step in a closure so second order methods work
         """
-        # ---------------------------
-        # FORWARD (TRAINING STEP + TRAIN STEP END)
-        # ---------------------------
-        with self.profiler.profile('model_forward'):
-            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
-            training_step_output = self.accelerator_backend.training_step(args)
-            training_step_output = self.call_hook('training_step_end', training_step_output)
-
-            # ----------------------------
-            # PROCESS THE RESULT
-            # ----------------------------
-            # format and reduce outputs accordingly
-            training_step_output_for_epoch_end = training_step_output
-            is_result_obj = isinstance(training_step_output, Result)
-
-            # track batch size for weighted average
-            if is_result_obj:
-                training_step_output.track_batch_size(len(split_batch))
-
-            # don't allow EvalResult in the training_step
-            if isinstance(training_step_output, EvalResult):
-                raise MisconfigurationException('training_step cannot return EvalResult, '
-                                                'use a dict or TrainResult instead')
-
-            # handle regular dicts
-            if not is_result_obj:
-                training_step_output = self.process_output(training_step_output, train=True)
-
-                training_step_output = AttributeDict(
-                    batch_loss=training_step_output[0],
-                    pbar_on_batch_end=training_step_output[1],
-                    log_metrics=training_step_output[2],
-                    callback_metrics=training_step_output[3],
-                    hiddens=training_step_output[4],
-                )
-
-            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
-                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
-            elif is_result_obj:
-                training_step_output_for_epoch_end = copy(training_step_output)
-                training_step_output_for_epoch_end.detach()
-            else:
-                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
-
-        # accumulate loss
-        # (if accumulate_grad_batches = 1 no effect)
-        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
-        closure_loss = closure_loss / self.accumulate_grad_batches
-
-        # the loss will get scaled for amp. avoid any modifications to it
-        untouched_loss = closure_loss.detach().clone()
+        # lightning module hook
+        result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)

         # backward pass
-        with self.profiler.profile('model_backward'):
-            closure_loss = self.accelerator_backend.backward(closure_loss, optimizer, opt_idx)
+        self.train_loop.backward(result, optimizer, opt_idx)

         # hook
-        self.train_loop.on_after_backward(training_step_output, batch_idx, untouched_loss)
+        self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.loss)

-        # result
-        result = AttributeDict(
-            loss=untouched_loss,
-            training_step_output=training_step_output,
-            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
-            hiddens=training_step_output.hiddens,
-        )
         return result

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
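Note on the hand-off between the two files: TrainLoop.training_step returns an AttributeDict (imported from pytorch_lightning.utilities.parsing in the second file). AttributeDict is a plain dict with attribute-style access, so result.loss and result['loss'] name the same entry, which is why optimizer_closure can read result.training_step_output and result.loss while TrainLoop.backward reassigns result.closure_loss in place. A minimal standalone sketch of that shape (toy tensor values for illustration, not the actual Trainer plumbing):

import torch
from pytorch_lightning.utilities.parsing import AttributeDict

# toy stand-in for the result built at the end of TrainLoop.training_step:
# closure_loss keeps the autograd graph for backward, loss is a detached copy
closure_loss = (torch.tensor(1.5, requires_grad=True) * 2) / 2  # pretend scaled loss
result = AttributeDict(
    closure_loss=closure_loss,
    loss=closure_loss.detach().clone(),
    hiddens=None,
)

assert result.loss is result['loss']      # attribute and key access hit the same entry
result.closure_loss.backward()            # what TrainLoop.backward ultimately drives
print(float(result.loss))                 # graph-free copy, safe to log or modify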
72 changes: 71 additions & 1 deletion pytorch_lightning/trainer/training_loop_temp.py
@@ -5,8 +5,12 @@
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.supporters import Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.core.step_result import Result
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities.parsing import AttributeDict
+from copy import copy


 class TrainLoop:
@@ -130,6 +134,11 @@ def get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.trainer.optimizers[opt_idx])]

+    def backward(self, result, optimizer, opt_idx):
+        # backward pass
+        with self.trainer.profiler.profile('model_backward'):
+            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+
     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)

@@ -143,3 +152,64 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):

         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
+
+    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
+        with self.trainer.profiler.profile('model_forward'):
+            args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
+            training_step_output = self.trainer.accelerator_backend.training_step(args)
+            training_step_output = self.trainer.call_hook('training_step_end', training_step_output)
+
+            # ----------------------------
+            # PROCESS THE RESULT
+            # ----------------------------
+            # format and reduce outputs accordingly
+            training_step_output_for_epoch_end = training_step_output
+            is_result_obj = isinstance(training_step_output, Result)
+
+            # track batch size for weighted average
+            if is_result_obj:
+                training_step_output.track_batch_size(len(split_batch))
+
+            # don't allow EvalResult in the training_step
+            if isinstance(training_step_output, EvalResult):
+                raise MisconfigurationException('training_step cannot return EvalResult, '
+                                                'use a dict or TrainResult instead')
+
+            # handle regular dicts
+            if not is_result_obj:
+                training_step_output = self.trainer.process_output(training_step_output, train=True)
+
+                training_step_output = AttributeDict(
+                    batch_loss=training_step_output[0],
+                    pbar_on_batch_end=training_step_output[1],
+                    log_metrics=training_step_output[2],
+                    callback_metrics=training_step_output[3],
+                    hiddens=training_step_output[4],
+                )
+
+            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
+            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
+                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
+            elif is_result_obj:
+                training_step_output_for_epoch_end = copy(training_step_output)
+                training_step_output_for_epoch_end.detach()
+            else:
+                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
+
+        # accumulate loss
+        # (if accumulate_grad_batches = 1 no effect)
+        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
+        closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+
+        # the loss will get scaled for amp. avoid any modifications to it
+        untouched_loss = closure_loss.detach().clone()
+
+        # result
+        result = AttributeDict(
+            closure_loss=closure_loss,
+            loss=untouched_loss,
+            training_step_output=training_step_output,
+            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
+            hiddens=training_step_output.hiddens,
+        )
+        return result
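
One detail worth spelling out from the accumulate-loss block above: dividing each closure_loss by self.trainer.accumulate_grad_batches before backward is what makes accumulated gradients match the gradient of the mean loss over the accumulated micro-batches. A self-contained check of that arithmetic in plain PyTorch (outside Lightning, with made-up data):

import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
batches = [torch.randn(4, 3) for _ in range(2)]
accumulate_grad_batches = 2  # mirrors Trainer(accumulate_grad_batches=2)

# accumulate: scale each micro-batch loss before backward, as the closure does
w.grad = None
for x in batches:
    loss = (x @ w).pow(2).mean() / accumulate_grad_batches
    loss.backward()  # gradients add up in w.grad across iterations
accumulated_grad = w.grad.clone()

# reference: gradient of the average of the per-batch losses, computed in one go
w.grad = None
mean_loss = sum((x @ w).pow(2).mean() for x in batches) / accumulate_grad_batches
mean_loss.backward()

assert torch.allclose(accumulated_grad, w.grad)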