Skip to content

Commit

Permalink
ref: inner train loop (intermediate step) 15/n (#3374)
Browse files Browse the repository at this point in the history
* ref: inner train loop (intermediate step) 15/n

* ref: inner train loop (intermediate step) 15/n
  • Loading branch information
williamFalcon authored Sep 7, 2020
1 parent 7073de8 commit 60a3e28
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 68 deletions.
63 changes: 0 additions & 63 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def training_step(self, batch, batch_idx):
from typing import Union, List

from torch.utils.data import DataLoader
from copy import deepcopy

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
Expand Down Expand Up @@ -243,10 +240,6 @@ def process_output(self, *args):
def call_hook(self, hook_name, *args, **kwargs):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def has_arg(self, *args):
    """Abstract placeholder for ``has_arg``.

    Warning: this is just an empty shell; the working code is implemented
    in another class.
    """

def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
"""
Figure out what needs to be tracked/logged at the end of the epoch
Expand Down Expand Up @@ -276,24 +269,6 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu

return epoch_end_outputs

def check_checkpoint_callback(self, should_check_val):
    """Fire ``on_validation_end`` on every ``ModelCheckpoint`` callback if needed.

    The callbacks are only triggered when ``validation_step`` is not
    overridden on the model and ``should_check_val`` is falsy.

    Args:
        should_check_val: a truthy value suppresses the checkpoint call.
    """
    # when no val loop is present or fast-dev-run still need to call checkpoints
    # TODO bake this logic into the checkpoint callback
    should_activate = not is_overridden('validation_step', self.get_model()) and not should_check_val
    if not should_activate:
        return

    # Snapshot the matching callbacks first, then invoke them with an explicit
    # loop: the calls are made for their side effects, not to build a list.
    checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
    for checkpoint_callback in checkpoint_callbacks:
        checkpoint_callback.on_validation_end(self, self.get_model())

def update_train_loop_lr_schedulers(self, monitor_metrics=None):
    """Request a ``'step'``-interval learning-rate update when one is due.

    An update is due when the 1-based batch count is a multiple of
    ``accumulate_grad_batches`` or equals ``num_training_batches``.

    Args:
        monitor_metrics: forwarded unchanged to ``update_learning_rates``.
    """
    batches_done = self.batch_idx + 1
    if batches_done % self.accumulate_grad_batches != 0 and batches_done != self.num_training_batches:
        # neither an accumulation boundary nor the last batch: nothing to do
        return

    # update lr
    self.update_learning_rates(interval='step', monitor_metrics=monitor_metrics)

def run_on_epoch_end_hook(self):
    """Dispatch the epoch-end hooks through ``call_hook``, in a fixed order."""
    # 'on_epoch_end' is dispatched before 'on_train_epoch_end'
    for hook_name in ('on_epoch_end', 'on_train_epoch_end'):
        self.call_hook(hook_name)

def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
# epoch output is a list. Each item in that list has all the outputs per optimizer
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
Expand Down Expand Up @@ -415,13 +390,6 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

return gathered_epoch_outputs

def increment_accumulated_grad_global_step(self):
    """Advance the step counters after a training batch.

    ``global_step`` grows by one when the 1-based batch count is a multiple
    of ``accumulate_grad_batches`` or equals ``num_training_batches``.
    """
    batches_done = self.batch_idx + 1

    # progress global step according to grads progress
    if not (batches_done % self.accumulate_grad_batches != 0 and batches_done != self.num_training_batches):
        self.global_step += 1

    # NOTE(review): assumes total_batch_idx advances on every call, i.e. it is
    # not gated by the condition above -- confirm against the original layout.
    self.total_batch_idx += 1

def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
# when metrics should be logged
should_log_metrics = (batch_idx + 1) % self.row_log_interval == 0 or self.should_stop
Expand All @@ -439,17 +407,6 @@ def save_loggers_in_training_loop(self, batch_idx):
if self.is_global_zero and self.logger is not None:
self.logger.save()

def should_check_val(self, batch_idx, is_last_batch):
    """Decide whether validation should run after the given batch.

    Validation must be enabled and the 1-based epoch count must be a multiple
    of ``check_val_every_n_epoch``. Given that, the check fires when the
    1-based batch count is a multiple of ``val_check_batch``, when
    ``should_stop`` is set, or on the last batch while ``val_check_batch``
    is infinite.

    Args:
        batch_idx: zero-based index of the batch that just finished.
        is_last_batch: whether that batch was the last one.

    Returns:
        The outcome of the ``and``/``or`` chain (truthy when validation
        should run).
    """
    on_val_interval = (batch_idx + 1) % self.val_check_batch == 0
    on_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
    validation_allowed = self.enable_validation and on_val_epoch
    triggered = on_val_interval or self.should_stop
    last_batch_of_unsized_data = is_last_batch and self.val_check_batch == float('inf')

    # mirror `a and (b or c)`: a falsy gate is returned as-is
    if not validation_allowed:
        return validation_allowed
    return triggered or last_batch_of_unsized_data

def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track grad norms
grad_norm_dic = {}
Expand Down Expand Up @@ -584,26 +541,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

return result

def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
    """Assemble the positional arguments for ``training_step``.

    The list always starts with ``batch`` and ``batch_idx``. ``opt_idx`` is
    added when more than one optimizer is configured, and ``hiddens`` is
    added when ``truncated_bptt_steps`` is set.

    Raises:
        ValueError: if several optimizers are configured but
            ``training_step`` has no ``optimizer_idx`` argument.
    """
    # enable not needing to add opt_idx to training_step
    step_args = [batch, batch_idx]

    if len(self.optimizers) > 1:
        if not self.has_arg('training_step', 'optimizer_idx'):
            num_opts = len(self.optimizers)
            raise ValueError(
                f'Your LightningModule defines {num_opts} optimizers but '
                f'training_step is missing the "optimizer_idx" argument.'
            )
        step_args += [opt_idx]

    # pass hiddens if using tbptt
    if self.truncated_bptt_steps is not None:
        step_args += [hiddens]

    return step_args

def update_learning_rates(self, interval: str, monitor_metrics=None):
"""Update learning rates.
Expand Down
62 changes: 57 additions & 5 deletions pytorch_lightning/trainer/training_loop_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):

def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
with self.trainer.profiler.profile('model_forward'):
args = self.trainer.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
training_step_output = self.trainer.accelerator_backend.training_step(args)
training_step_output = self.trainer.call_hook('training_step_end', training_step_output)

Expand Down Expand Up @@ -344,7 +344,7 @@ def run_training_epoch(self):
# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self.trainer.should_check_val(batch_idx, is_last_batch)
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.run_evaluation(test_mode=False)

Expand All @@ -361,10 +361,10 @@ def run_training_epoch(self):
# update LR schedulers
monitor_metrics = deepcopy(self.trainer.callback_metrics)
monitor_metrics.update(batch_output.batch_log_metrics)
self.trainer.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)

# progress global step according to grads progress
self.trainer.increment_accumulated_grad_global_step()
self.increment_accumulated_grad_global_step()

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step:
Expand All @@ -388,4 +388,56 @@ def run_training_epoch(self):
self.check_checkpoint_callback(self.should_check_val)

# epoch end hook
self.trainer.run_on_epoch_end_hook()
self.run_on_epoch_end_hook()

def update_train_loop_lr_schedulers(self, monitor_metrics=None):
    """Ask the trainer for a ``'step'``-interval learning-rate update when due.

    An update is due when the trainer's 1-based batch count is a multiple of
    ``accumulate_grad_batches`` or equals ``num_training_batches``.

    Args:
        monitor_metrics: forwarded unchanged to ``trainer.update_learning_rates``.
    """
    trainer = self.trainer
    accumulation_complete = (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0
    final_batch_reached = (trainer.batch_idx + 1) == trainer.num_training_batches

    if not (accumulation_complete or final_batch_reached):
        return

    # update lr
    trainer.update_learning_rates(interval='step', monitor_metrics=monitor_metrics)

def run_on_epoch_end_hook(self):
    """Dispatch the epoch-end hooks through ``trainer.call_hook``, in order."""
    dispatch = self.trainer.call_hook
    # 'on_epoch_end' is dispatched before 'on_train_epoch_end'
    for hook_name in ('on_epoch_end', 'on_train_epoch_end'):
        dispatch(hook_name)

def increment_accumulated_grad_global_step(self):
    """Advance the trainer's step counters after a training batch.

    ``trainer.global_step`` grows by one when the 1-based batch count is a
    multiple of ``accumulate_grad_batches`` or equals ``num_training_batches``.
    """
    trainer = self.trainer
    accumulation_complete = (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0
    final_batch_reached = (trainer.batch_idx + 1) == trainer.num_training_batches

    # progress global step according to grads progress
    if any((accumulation_complete, final_batch_reached)):
        trainer.global_step += 1

    # NOTE(review): assumes total_batch_idx advances on every call, i.e. it is
    # not gated by the condition above -- confirm against the original layout.
    trainer.total_batch_idx += 1

def should_check_val_fx(self, batch_idx, is_last_batch):
    """Decide whether validation should run after the given batch.

    Validation must be enabled on the trainer and the 1-based epoch count
    must be a multiple of ``check_val_every_n_epoch``. Given that, the check
    fires when the 1-based batch count is a multiple of ``val_check_batch``,
    when ``trainer.should_stop`` is set, or on the last batch while
    ``val_check_batch`` is infinite.

    Args:
        batch_idx: zero-based index of the batch that just finished.
        is_last_batch: whether that batch was the last one.

    Returns:
        The outcome of the ``and``/``or`` chain (truthy when validation
        should run).
    """
    trainer = self.trainer
    on_val_interval = (batch_idx + 1) % trainer.val_check_batch == 0
    on_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
    validation_allowed = trainer.enable_validation and on_val_epoch
    triggered = on_val_interval or trainer.should_stop
    last_batch_of_unsized_data = is_last_batch and trainer.val_check_batch == float('inf')

    # mirror `a and (b or c)`: a falsy gate is returned as-is
    if not validation_allowed:
        return validation_allowed
    return triggered or last_batch_of_unsized_data

def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
    """Assemble the positional arguments for ``training_step``.

    The list always starts with ``batch`` and ``batch_idx``. ``opt_idx`` is
    added when the trainer holds more than one optimizer, and ``hiddens`` is
    added when ``trainer.truncated_bptt_steps`` is set.

    Raises:
        ValueError: if several optimizers are configured but
            ``training_step`` has no ``optimizer_idx`` argument.
    """
    # enable not needing to add opt_idx to training_step
    step_args = [batch, batch_idx]

    if len(self.trainer.optimizers) > 1:
        if not self.trainer.has_arg('training_step', 'optimizer_idx'):
            num_opts = len(self.trainer.optimizers)
            raise ValueError(
                f'Your LightningModule defines {num_opts} optimizers but '
                f'training_step is missing the "optimizer_idx" argument.'
            )
        step_args += [opt_idx]

    # pass hiddens if using tbptt
    if self.trainer.truncated_bptt_steps is not None:
        step_args += [hiddens]

    return step_args

0 comments on commit 60a3e28

Please sign in to comment.