diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d2395ee566b0c..ddcfd5e0c9629 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -83,6 +83,8 @@ def __init__(
             min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
             max_epochs=1000,
             min_epochs=1,
+            max_steps=None,
+            min_steps=None,
             train_percent_check=1.0,
             val_percent_check=1.0,
             test_percent_check=1.0,
@@ -345,6 +347,20 @@ def __init__(
                 .. warning:: .. deprecated:: 0.5.0
                     Use `min_nb_epochs` instead. Will remove 0.8.0.
 
+            max_steps (int): Stop training after this number of steps. Disabled by default (None).
+                Training stops when either max_steps or max_epochs is reached (whichever comes first).
+                Example::
+
+                    # Stop after 100 steps
+                    trainer = Trainer(max_steps=100)
+
+            min_steps (int): Force training for at least this number of steps. Disabled by default (None).
+                Trainer will train the model for at least min_steps or min_epochs (whichever comes last).
+                Example::
+
+                    # Run for at least 100 steps (disable min_epochs)
+                    trainer = Trainer(min_steps=100, min_epochs=0)
+
             train_percent_check (int): How much of training dataset to check.
                 Useful when debugging or testing something that happens at the end of an epoch.
                 Example::
@@ -610,6 +626,9 @@ def __init__(
             min_epochs = min_nb_epochs
         self.min_epochs = min_epochs
 
+        self.max_steps = max_steps
+        self.min_steps = min_steps
+
         # Backward compatibility
         if nb_sanity_val_steps is not None:
             warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d8bdb3f49c585..34c5ee0473121 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -372,10 +372,17 @@ def train(self):
                     raise MisconfigurationException(m)
                 self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)
 
+            if self.max_steps and self.max_steps == self.global_step:
+                self.main_progress_bar.close()
+                model.on_train_end()
+                return
+
             # early stopping
             met_min_epochs = epoch >= self.min_epochs - 1
+            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
+
             if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                    (met_min_epochs or self.fast_dev_run)):
+                    ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
                 should_stop = self.early_stop_callback.on_epoch_end()
                 # stop training
                 stop = should_stop and met_min_epochs
@@ -463,6 +470,10 @@ def run_training_epoch(self):
             self.global_step += 1
             self.total_batch_idx += 1
 
+            # max steps reached, end training
+            if self.max_steps is not None and self.max_steps == self.global_step:
+                break
+
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
             # requested in the batches
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 06cd8a7de93cd..2b064357ea7de 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -1,3 +1,4 @@
+import math
 import os
 
 import pytest
@@ -6,6 +7,7 @@
 import tests.models.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
+    EarlyStopping,
     ModelCheckpoint,
 )
 from tests.models import (
@@ -447,5 +449,89 @@ class CurrentTestModel(
     trainer.test()
 
 
+def _init_steps_model():
+    """Set up a model and trainer options using 5% of the training data per epoch."""
+    tutils.reset_seed()
+    model, _ = tutils.get_model()
+
+    # limit each train epoch to 5% of the data
+    train_percent = 0.05
+    # get number of samples in 1 epoch
+    num_train_samples = math.floor(len(model.train_dataloader()) * train_percent)
+
+    trainer_options = dict(
+        train_percent_check=train_percent,
+    )
+    return model, trainer_options, num_train_samples
+
+
+def test_trainer_max_steps_and_epochs(tmpdir):
+    """Verify model trains according to specified max steps"""
+    model, trainer_options, num_train_samples = _init_steps_model()
+
+    # max_steps should be reached before max_epochs
+    trainer_options.update(dict(
+        max_epochs=5,
+        max_steps=num_train_samples + 10
+    ))
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training did not complete"
+
+    # check training stopped at max_steps
+    assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"
+
+    # max_epochs should be reached before max_steps
+    trainer_options['max_epochs'] = 2
+    trainer_options['max_steps'] = trainer_options['max_epochs'] * 2 * num_train_samples
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training did not complete"
+
+    # check training stopped at max_epochs
+    assert trainer.global_step == num_train_samples * trainer.max_epochs \
+        and trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"
+
+
+def test_trainer_min_steps_and_epochs(tmpdir):
+    """Verify model trains according to specified min steps"""
+    model, trainer_options, num_train_samples = _init_steps_model()
+
+    # define callback for stopping the model and default epochs
+    trainer_options.update({
+        'early_stop_callback': EarlyStopping(monitor='val_loss', min_delta=1.0),
+        'val_check_interval': 20,
+        'min_epochs': 1,
+        'max_epochs': 10
+    })
+
+    # min_steps is shorter than one epoch, so min_epochs dominates
+    trainer_options['min_steps'] = math.floor(num_train_samples / 2)
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training did not complete"
+
+    # check model ran for at least min_epochs
+    assert trainer.global_step >= num_train_samples and \
+        trainer.current_epoch > 0, "Model did not train for at least min_epochs"
+
+    # min_steps now spans more than min_epochs, so min_steps dominates
+    trainer_options['min_steps'] = math.floor(num_train_samples * 1.5)
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1, "Training did not complete"
+
+    # check model ran for at least num_train_samples * 1.5 steps
+    assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
+        trainer.current_epoch > 0, "Model did not train for at least min_steps"
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
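
A brief usage sketch of the flags added above (not part of the patch itself; `MyModel` is a
placeholder for any LightningModule and is not defined here). Per the training_loop.py changes,
training halts at whichever of max_steps / max_epochs is hit first, while early stopping is
deferred until both min_steps and min_epochs have been satisfied::

    from pytorch_lightning import Trainer

    model = MyModel()  # placeholder LightningModule, not defined in this patch

    # stop at whichever limit is reached first: 500 optimizer steps or 10 epochs;
    # train for at least 100 steps and at least 1 epoch before early stopping may trigger
    trainer = Trainer(min_steps=100, min_epochs=1, max_steps=500, max_epochs=10)
    trainer.fit(model)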