ref: moving train loop to own object 2/n (intermediate steps) #3313

Merged · 2 commits · Sep 2, 2020
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,25 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(self._with_is_last(train_dataloader)),
"get_train_batch"
)
return profiled_dl

def _with_is_last(self, iterable):
"""Pass through values from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
it = iter(iterable)
last = next(it)
for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True

def prepare_data(self, model):
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
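For reference, a minimal standalone sketch (not part of the diff) of what the new helper yields; the body mirrors _with_is_last above and the sample values are hypothetical:

# Standalone sketch of DataConnector._with_is_last: every element is paired with a
# flag that is True only for the final item (sample data is hypothetical).
def with_is_last(iterable):
    it = iter(iterable)
    last = next(it)          # hold one element back so the end can be detected
    for val in it:
        yield last, False    # more items follow
        last = val
    yield last, True         # the held-back element is the last one

print(list(with_is_last([10, 20, 30])))
# -> [(10, False), (20, False), (30, True)]

get_profiled_train_dataloader then wraps enumerate(...) of that generator in profiler.profile_iterable, so each batch fetch is timed under the "get_train_batch" key.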
42 changes: 10 additions & 32 deletions pytorch_lightning/trainer/training_loop.py
@@ -183,6 +183,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector

try:
from apex import amp
@@ -264,6 +265,7 @@ class TrainerTrainLoopMixin(ABC):
accelerator_backend: ...
val_dataloaders: ...
train_loop: TrainLoop
data_connector: DataConnector

# Callback system
callbacks: List[Callback]
@@ -443,10 +445,10 @@ def run_training_epoch(self):
# track epoch output
epoch_output = [[] for _ in range(self.train_loop.num_optimizers)]

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
):
# enable profiling for the dataloader
train_dataloader = self.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
for batch_idx, (batch, is_last_batch) in train_dataloader:
# stop epoch if we limited the number of training batches
if batch_idx >= self.num_training_batches:
break
@@ -457,7 +459,7 @@
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx)
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
@@ -467,12 +469,8 @@
self.train_loop.checkpoint_accumulator
)

# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)
# hook
self.train_loop.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)

# when returning -1 from train_step, we end epoch early
self.should_stop = batch_output.signal == -1
@@ -748,7 +746,7 @@ def should_check_val(self, batch_idx, is_last_batch):

return should_check_val

def run_training_batch(self, batch, batch_idx):
def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track grad norms
grad_norm_dic = {}

@@ -767,7 +765,6 @@ def run_training_batch(self, batch, batch_idx):
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

# hook
dataloader_idx = 0
response = self.call_hook('on_batch_start')
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
@@ -859,12 +856,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# reset for next set of accumulated grads
self.batch_loss_value.reset()

# hook
self.call_hook('on_batch_end')

# hook
self.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

# collapse all metrics into one dict
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

@@ -1186,16 +1177,3 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
scheduler_idx,
old_lr, new_lr
)


def _with_is_last(iterable):
"""Pass through values from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
it = iter(iterable)
last = next(it)
for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True
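Taken together, the epoch loop above now delegates dataloader wrapping to DataConnector and per-batch bookkeeping to TrainLoop. A simplified, hypothetical sketch of the resulting control flow (names abbreviated; the call that produces epoch_end_outputs is elided in the diff and left as a placeholder here):

# Simplified sketch of the refactored run_training_epoch flow; not the actual
# Trainer code, error handling and logging omitted.
def run_training_epoch_sketch(trainer, raw_train_dataloader):
    epoch_output = [[] for _ in range(trainer.train_loop.num_optimizers)]

    # DataConnector adds profiling and the (batch, is_last_batch) tagging
    train_dataloader = trainer.data_connector.get_profiled_train_dataloader(raw_train_dataloader)
    dataloader_idx = 0

    for batch_idx, (batch, is_last_batch) in train_dataloader:
        # stop the epoch if the number of training batches was limited
        if batch_idx >= trainer.num_training_batches:
            break

        # TRAINING_STEP + TRAINING_STEP_END, now with an explicit dataloader_idx
        batch_output = trainer.run_training_batch(batch, batch_idx, dataloader_idx)

        epoch_end_outputs = ...  # produced from batch_output; the exact call is elided in the diff

        # epoch-end tracking plus the on_batch_end / on_train_batch_end hooks
        # now sit behind a single TrainLoop call
        trainer.train_loop.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)

        # returning -1 from training_step signals an early end of the epoch
        if batch_output.signal == -1:
            break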
18 changes: 18 additions & 0 deletions pytorch_lightning/trainer/training_loop_temp.py
@@ -1,5 +1,6 @@
from pytorch_lightning.trainer.supporters import Accumulator
import numpy as np
from pytorch_lightning.core.step_result import Result


class TrainLoop:
@@ -27,6 +28,23 @@ def on_train_epoch_start(self):
self.early_stopping_accumulator = Accumulator()
self.checkpoint_accumulator = Accumulator()

def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
# figure out what to track for epoch end
self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)

# hook
self.trainer.call_hook('on_batch_end')
self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)


def get_optimizers_iterable(self):
"""
Generates an iterable with (idx, optimizer) for each optimizer.
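The list-unwrapping rule that moved into track_epoch_end_reduce_metrics keeps single-step (no truncated BPTT) outputs out of a needless one-element list while leaving Result objects wrapped. A small self-contained sketch, with a local stand-in for Result and hypothetical data:

# Sketch of TrainLoop.track_epoch_end_reduce_metrics; Result here is a local
# stand-in for pytorch_lightning.core.step_result.Result and the data is hypothetical.
class Result(dict):
    pass

def track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs):
    # track the outputs to reduce at the end of the epoch
    for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
        # with 1 step (no tbptt) don't keep a one-element sequence around
        if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
            opt_outputs = opt_outputs[0]
        epoch_output[opt_idx].append(opt_outputs)

epoch_output = [[], []]                                     # one bucket per optimizer
epoch_end_outputs = [[{"loss": 0.5}], [Result(loss=0.2)]]   # plain dict vs. Result
track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
# epoch_output[0] == [{"loss": 0.5}]        -- the one-element list was unwrapped
# epoch_output[1] == [[Result(loss=0.2)]]   -- Result outputs keep their sequence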
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_dict_return.py
@@ -30,7 +30,7 @@ def test_training_step_dict(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
@@ -76,7 +76,7 @@ def training_step_with_step_end(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -117,7 +117,7 @@ def test_full_training_loop_dict(tmpdir):
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -204,7 +204,7 @@ def test_train_step_epoch_end(tmpdir):
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -69,7 +69,7 @@ def test_training_step_result_log_step_only(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
@@ -144,7 +144,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0

@@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2

@@ -356,7 +356,7 @@ def test_training_step_epoch_end_result(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2

8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_scalar_return.py
@@ -31,7 +31,7 @@ def test_training_step_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -68,7 +68,7 @@ def training_step_scalar_with_step_end(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -115,7 +115,7 @@ def test_full_training_loop_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -158,7 +158,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)