Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added max/min number of steps in Trainer #728

Merged
merged 25 commits into from
Feb 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5d7c59e
Added max number of steps in Trainer
Jan 22, 2020
47b4119
Added docstring
Jan 28, 2020
00c3738
Fix flake8 errors
Jan 28, 2020
e468f31
Clarified docstrings
Jan 30, 2020
62b0b08
Fixed flake8 error
peteriz Jan 30, 2020
56fc410
Added min_steps to Trainer
Feb 2, 2020
ea66c7e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 2, 2020
b9e91d4
Added steps and epochs test
Feb 3, 2020
203a441
flake8
Feb 3, 2020
9007772
minor fix
peteriz Feb 3, 2020
9d6590a
fix steps test in test_trainer
Feb 3, 2020
5af168e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
167322f
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
242af5e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
24d967e
Split steps test into 2 tests
Feb 9, 2020
b967119
Refactor steps test
Feb 11, 2020
3440196
Update test_trainer.py
peteriz Feb 11, 2020
9ac7f69
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Feb 13, 2020
ec71b4f
Minor in test_trainer.py
Feb 13, 2020
4e7adfa
Update test_trainer.py
williamFalcon Feb 15, 2020
075fb46
Merge branch 'master' into master
Feb 16, 2020
927d18f
Address PR comments
Feb 16, 2020
6903788
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 16, 2020
c9285bb
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Feb 18, 2020
b66703c
Minor
Feb 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_epochs=1000,
min_epochs=1,
max_steps=None,
peteriz marked this conversation as resolved.
Show resolved Hide resolved
min_steps=None,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
Expand Down Expand Up @@ -345,6 +347,20 @@ def __init__(
.. warning:: .. deprecated:: 0.5.0
Use `min_nb_epochs` instead. Will remove 0.8.0.

max_steps (int): Stop training after this number of steps. Disabled by default (None).
Training will stop when either max_steps or max_epochs is reached (whichever comes first).
Example::
peteriz marked this conversation as resolved.
Show resolved Hide resolved

# Stop after 100 steps
trainer = Trainer(max_steps=100)

min_steps (int): Force training for at least this number of steps. Disabled by default (None).
Trainer will train the model for at least min_steps or min_epochs (whichever comes last).
Example::

# Run at least for 100 steps (disable min_epochs)
trainer = Trainer(min_steps=100, min_epochs=0)

train_percent_check (int): How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
Expand Down Expand Up @@ -610,6 +626,9 @@ def __init__(
min_epochs = min_nb_epochs
self.min_epochs = min_epochs

self.max_steps = max_steps
self.min_steps = min_steps

# Backward compatibility
if nb_sanity_val_steps is not None:
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
Expand Down
13 changes: 12 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,17 @@ def train(self):
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
(met_min_epochs or self.fast_dev_run)):
((met_min_epochs and met_min_steps) or self.fast_dev_run)):
Borda marked this conversation as resolved.
Show resolved Hide resolved
should_stop = self.early_stop_callback.on_epoch_end()
# stop training
stop = should_stop and met_min_epochs
Expand Down Expand Up @@ -463,6 +470,10 @@ def run_training_epoch(self):
self.global_step += 1
self.total_batch_idx += 1

# max steps reached, end training
if self.max_steps is not None and self.max_steps == self.global_step:
break

# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
Expand Down
86 changes: 86 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os

import pytest
Expand All @@ -6,6 +7,7 @@
import tests.models.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
)
from tests.models import (
Expand Down Expand Up @@ -447,5 +449,89 @@ class CurrentTestModel(
trainer.test()


def _init_steps_model():
    """Initialize a model whose training epoch covers 5% of the data.

    Returns:
        tuple: ``(model, trainer_options, num_train_samples)`` where
        ``trainer_options`` is a dict suitable for ``Trainer(**trainer_options)``
        and ``num_train_samples`` is the number of batches in one truncated epoch.
    """
    tutils.reset_seed()
    model, _ = tutils.get_model()

    # limit each training epoch to 5% of the data
    train_percent = 0.05
    # number of batches that make up one (truncated) epoch
    num_train_samples = math.floor(len(model.train_dataloader()) * train_percent)

    trainer_options = dict(
        train_percent_check=train_percent,
    )
    return model, trainer_options, num_train_samples


def test_trainer_max_steps_and_epochs(tmpdir):
    """Verify model trains according to specified max steps."""
    model, trainer_options, num_train_samples = _init_steps_model()

    # define fewer train steps than one epoch so max_steps triggers first
    trainer_options.update(dict(
        max_epochs=5,
        max_steps=num_train_samples + 10
    ))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, "Training did not complete"

    # check training stopped at max_steps
    assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"

    # define fewer train epochs than steps so max_epochs triggers first
    trainer_options['max_epochs'] = 2
    trainer_options['max_steps'] = trainer_options['max_epochs'] * 2 * num_train_samples

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, "Training did not complete"

    # check training stopped at max_epochs
    # FIX: read `trainer.max_epochs` (the attribute set above) rather than the
    # deprecated `max_nb_epochs` alias used inconsistently before.
    assert trainer.global_step == num_train_samples * trainer.max_epochs \
        and trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"


def test_trainer_min_steps_and_epochs(tmpdir):
    """Verify model trains according to specified min steps"""
    model, trainer_options, num_train_samples = _init_steps_model()

    # an early-stopping callback that would halt training immediately, plus
    # default epoch bounds; min_steps / min_epochs must keep the run alive
    trainer_options.update({
        'early_stop_callback': EarlyStopping(monitor='val_loss', min_delta=1.0),
        'val_check_interval': 20,
        'min_epochs': 1,
        'max_epochs': 10
    })

    # case 1: min_steps covers less than one epoch -> min_epochs dominates
    trainer_options['min_steps'] = math.floor(num_train_samples / 2)

    trainer = Trainer(**trainer_options)
    assert trainer.fit(model) == 1, "Training did not complete"

    # at least one full epoch must have been run despite early stopping
    assert trainer.global_step >= num_train_samples and \
        trainer.current_epoch > 0, "Model did not train for at least min_epochs"

    # case 2: min_steps covers more than min_epochs -> min_steps dominates
    trainer_options['min_steps'] = math.floor(num_train_samples * 1.5)

    trainer = Trainer(**trainer_options)
    assert trainer.fit(model) == 1, "Training did not complete"

    # at least 1.5 epochs worth of optimizer steps must have been run
    assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
        trainer.current_epoch > 0, "Model did not train for at least min_steps"

# if __name__ == '__main__':
# pytest.main([__file__])