ref: inner train loop (intermediate step) 10/n #3369

Merged: 1 commit, Sep 6, 2020
29 changes: 17 additions & 12 deletions pytorch_lightning/trainer/training_loop.py
@@ -739,10 +739,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track metrics to log
batch_log_metrics = []

# bookkeeping
using_results_obj = False
self.hiddens = None

# track all outputs across time and num of optimizers
batch_outputs = [[] for i in range(len(self.train_loop.get_optimizers_iterable()))]
batch_outputs = [[] for _ in range(len(self.train_loop.get_optimizers_iterable()))]

if batch is None:
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@@ -757,16 +759,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

splits = [batch]
if self.truncated_bptt_steps is not None:
model_ref = self.get_model()
with self.profiler.profile('tbptt_split_batch'):
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
# lightning module hook
splits = self.train_loop.tbptt_split_batch(batch)

self.hiddens = None
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx

# loop over optimizers
for opt_idx, optimizer in self.train_loop.get_optimizers_iterable():
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
@@ -780,7 +779,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# -------------------
# calculate loss (train step + train step end)
# -------------------
opt_closure_result = self.optimizer_closure(
opt_closure_result = self.training_step_and_backward(
split_batch,
batch_idx,
opt_idx,
@@ -808,13 +807,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
or (self.batch_idx + 1) == self.num_training_batches):
accumulation_done = (self.batch_idx + 1) % self.accumulate_grad_batches == 0
is_final_batch = (self.batch_idx + 1) == self.num_training_batches
if accumulation_done or is_final_batch:
# hook
grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)

# wrap forward + backward pass in closure for 2nd order optimizers
train_step_and_backward_closure = lambda: self.training_step_and_backward(
split_batch, batch_idx, opt_idx, optimizer, self.hiddens,
).loss

# optimizer step
self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

# hook
self.train_loop.on_before_zero_grad(optimizer)
@@ -843,7 +848,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
)
return result

def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
wrap the forward step in a closure so second order methods work
"""
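The substantive change in this file is that the forward and backward pass, now named training_step_and_backward (previously optimizer_closure), is wrapped in a closure and handed to train_loop.optimizer_step instead of the raw split_batch; the gradient-accumulation condition is also split into the named flags accumulation_done and is_final_batch for readability. The snippet below is a minimal plain-PyTorch sketch (not Lightning code) of why second-order optimizers need such a closure: LBFGS re-evaluates it internally, possibly several times per step.

```python
import torch

# Plain-PyTorch sketch (not Lightning internals): LBFGS calls the closure
# itself during its line search, possibly several times per step, which is why
# the PR passes the forward+backward pass to optimizer_step as a closure
# instead of passing the batch.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

# LBFGS evaluates the closure internally; first-order optimizers accept a
# closure too but typically call it only once.
optimizer.step(closure)
```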
24 changes: 11 additions & 13 deletions pytorch_lightning/trainer/training_loop_temp.py
@@ -214,21 +214,11 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
)
return result

def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
# calls .step(), .zero_grad()
# override function to modify this behavior

def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
with self.trainer.profiler.profile('optimizer_step'):
lambda_closure = lambda: self.trainer.optimizer_closure(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.trainer.hiddens,
).loss

# optimizer step lightningModule hook
self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx,
train_step_and_backward_closure)

def on_before_zero_grad(self, optimizer):
model = self.trainer.get_model()
@@ -280,3 +270,11 @@ def process_hiddens(self, opt_closure_result):
if isinstance(opt_closure_result.training_step_output, Result):
opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
return hiddens

def tbptt_split_batch(self, batch):
splits = [batch]
if self.trainer.truncated_bptt_steps is not None:
model_ref = self.trainer.get_model()
with self.trainer.profiler.profile('tbptt_split_batch'):
splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
return splits
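The truncated-BPTT split that previously lived inline in run_training_batch is now a small helper on the train loop: it keeps the profiler call and delegates to the LightningModule's tbptt_split_batch hook. As an illustration only (not the library's default implementation, and MySequenceModel is a hypothetical name), an override of that hook for an (inputs, targets) batch with the time dimension at index 1 could look roughly like this:

```python
import pytorch_lightning as pl
import torch

class MySequenceModel(pl.LightningModule):
    # Hypothetical override of the tbptt_split_batch hook that the new
    # train_loop.tbptt_split_batch helper delegates to. Assumes the batch is an
    # (inputs, targets) pair shaped (batch_size, time, features).
    def tbptt_split_batch(self, batch, split_size):
        x, y = batch
        return [
            (x[:, t:t + split_size], y[:, t:t + split_size])
            for t in range(0, x.size(1), split_size)
        ]
```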
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_steps_dict_return.py
@@ -47,7 +47,7 @@ def test_training_step_dict(tmpdir):
assert pbar_metrics['pbar_acc2'] == 19.0

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -84,7 +84,7 @@ def test_training_step_result_log_step_only(tmpdir):
assert f'step_log_acc2_b{batch_idx}' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -158,7 +158,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -293,7 +293,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -372,7 +372,7 @@ def test_training_step_epoch_end_result(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_scalar_return.py
@@ -43,7 +43,7 @@ def test_training_step_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -80,7 +80,7 @@ def training_step_scalar_with_step_end(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -127,7 +127,7 @@ def test_full_training_loop_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -170,5 +170,5 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171