Added max/min number of steps in Trainer #728

Merged Feb 18, 2020 (25 commits)
Changes from 3 commits
Commits
5d7c59e
Added max number of steps in Trainer
Jan 22, 2020
47b4119
Added docstring
Jan 28, 2020
00c3738
Fix flake8 errors
Jan 28, 2020
e468f31
Clarified docstrings
Jan 30, 2020
62b0b08
Fixed flake8 error
peteriz Jan 30, 2020
56fc410
Added min_steps to Trainer
Feb 2, 2020
ea66c7e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 2, 2020
b9e91d4
Added steps and epochs test
Feb 3, 2020
203a441
flake8
Feb 3, 2020
9007772
minor fix
peteriz Feb 3, 2020
9d6590a
fix steps test in test_trainer
Feb 3, 2020
5af168e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
167322f
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
242af5e
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 3, 2020
24d967e
Split steps test into 2 tests
Feb 9, 2020
b967119
Refactor steps test
Feb 11, 2020
3440196
Update test_trainer.py
peteriz Feb 11, 2020
9ac7f69
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Feb 13, 2020
ec71b4f
Minor in test_trainer.py
Feb 13, 2020
4e7adfa
Update test_trainer.py
williamFalcon Feb 15, 2020
075fb46
Merge branch 'master' into master
Feb 16, 2020
927d18f
Address PR comments
Feb 16, 2020
6903788
Merge branch 'master' of https://github.com/peteriz/pytorch-lightning
Feb 16, 2020
c9285bb
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Feb 18, 2020
b66703c
Minor
Feb 18, 2020
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -71,6 +71,7 @@ def __init__(
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_epochs=1000,
min_epochs=1,
max_steps=None,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
@@ -289,6 +290,15 @@ def __init__(
.. deprecated:: 0.5.0
Use `min_nb_epochs` instead. Will remove 0.8.0.

max_steps (int): Stop training after this number of steps.
Example::

# default used by the Trainer (disabled)
trainer = Trainer(max_steps=None)

# Stop after 100 steps (batches)
trainer = Trainer(max_steps=100)

train_percent_check (int): How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
@@ -505,6 +515,8 @@ def __init__(
min_epochs = min_nb_epochs
self.min_epochs = min_epochs

self.max_steps = max_steps

# Backward compatibility
if nb_sanity_val_steps is not None:
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
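
Note on the `max_steps` docstring above: the arithmetic behind "stop after this number of steps" can be sketched as follows. This is a minimal illustration, not part of the PR. The helper name `expected_stop` is hypothetical, and it assumes that one step equals one training batch, that early stopping never triggers, and that `max_steps` is at least 1 (the diff compares with `==` after the step counter is incremented, so a value below 1 would never match).

```python
from typing import Optional, Tuple


def expected_stop(max_steps: Optional[int],
                  batches_per_epoch: int,
                  max_epochs: int) -> Tuple[int, int]:
    """Hypothetical helper: return (total_steps, last_epoch_index).

    Assumes no early stopping and one global step per training batch.
    """
    if batches_per_epoch < 1 or max_epochs < 1:
        raise ValueError("batches_per_epoch and max_epochs must be >= 1")
    if max_steps is not None and max_steps < 1:
        # An equality check after incrementing the counter never matches 0.
        raise ValueError("max_steps must be >= 1 or None")

    # Steps available if only max_epochs limits the run.
    epoch_budget = max_epochs * batches_per_epoch
    total = epoch_budget if max_steps is None else min(max_steps, epoch_budget)

    # Zero-based index of the epoch in which the final step happens.
    last_epoch = (total - 1) // batches_per_epoch
    return total, last_epoch


# 100 steps with 30 batches per epoch: the run ends during epoch index 3.
print(expected_stop(100, 30, 1000))
```

Under these assumptions, whichever of `max_steps` and `max_epochs` is reached first ends the run.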
9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -345,6 +345,11 @@ def train(self):
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

if self.max_steps is not None and self.max_steps == self.global_step:
self.main_progress_bar.close()
model.on_train_end()
return

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
@@ -421,6 +426,10 @@ def run_training_epoch(self):
self.global_step += 1
self.total_batch_idx += 1

# max steps reached, end training
if self.max_steps is not None and self.max_steps == self.global_step:
break

# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
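
Note on the two hunks above: the same `max_steps == global_step` check appears twice, once to break out of the batch loop and once to return from the epoch loop. The toy simulation below shows why both are needed. It is only a sketch under stated assumptions: `simulate_training` is a hypothetical stand-in, not the real `Trainer`, and it omits the model, validation, early stopping, and progress-bar handling.

```python
from typing import Dict, Optional, Union


def simulate_training(max_epochs: int,
                      batches_per_epoch: int,
                      max_steps: Optional[int] = None
                      ) -> Dict[str, Union[int, str]]:
    """Toy stand-in for the nested epoch/batch loops (no real training)."""
    global_step = 0
    epochs_started = 0

    for _epoch in range(max_epochs):
        epochs_started += 1

        # Inner loop: mirrors the check added in run_training_epoch.
        for _batch_idx in range(batches_per_epoch):
            global_step += 1
            if max_steps is not None and max_steps == global_step:
                break  # leaves only the batch loop

        # Outer loop: mirrors the check added in train(). Without it,
        # the next epoch would start and the counter would pass max_steps.
        if max_steps is not None and max_steps == global_step:
            return {"global_step": global_step,
                    "epochs_started": epochs_started,
                    "stopped_by": "max_steps"}

    return {"global_step": global_step,
            "epochs_started": epochs_started,
            "stopped_by": "max_epochs"}


print(simulate_training(max_epochs=10, batches_per_epoch=30, max_steps=100))
```

Because the comparison uses `==` rather than `>=`, the sketch relies on the counter increasing by exactly one per batch. A counter that skipped past `max_steps` would never trigger the stop.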