Disable training with zero num_training_batches when insufficient limit_train_batches (#5703)

* disable training when zero num_train_batches with limit_train_batches

* refactor train skip condition

* fix formatting issues

* fix formatting issues

* ref: test error msg

* fix tests for data loader calls

* fix train dataloader condition

* update limit_train_batches upper range in test comment

* remove model state check test

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
kaushikb11 and mergify[bot] authored Feb 3, 2021
1 parent 793fe73 commit ba04bb3
Showing 2 changed files with 65 additions and 15 deletions.
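
The failure mode this addresses, sketched from the new test further down (BoringModel and RandomDataset are helpers from the repository's own test suite; the concrete numbers are illustrative):

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.base import BoringModel, RandomDataset

# 100 samples with batch_size=25 gives 4 batches; int(4 * 0.2) == 0,
# so not a single training batch would run in any epoch.
train_loader = DataLoader(RandomDataset(32, length=100), batch_size=25)
trainer = Trainer(max_epochs=5, limit_train_batches=0.2)
trainer.fit(BoringModel(), train_loader)

# With this commit the training loop is skipped outright:
assert trainer.num_training_batches == 0
assert trainer.global_step == 0
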
12 changes: 3 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -91,13 +91,7 @@ def num_optimizers(self):
         return num_optimizers
 
     def should_skip_training(self):
-        if self.trainer.current_epoch >= self.trainer.max_epochs:
-            return True
-
-        if self.trainer.limit_train_batches == 0:
-            return True
-
-        return False
+        return self.trainer.current_epoch >= self.trainer.max_epochs or self.trainer.num_training_batches == 0
 
     def on_train_start(self):
         # clear cache before training
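
The rewritten check keys off the resolved batch count instead of the raw `limit_train_batches` flag, which also covers the fractional case. A minimal sketch of the rounding that can drive it to zero (the formula follows the new test's docstring; the variable names here are illustrative):

batches_in_loader = 4                # e.g. 100 samples at batch_size=25
limit_train_batches = 0.2            # float in [0.0, 1.0]
num_training_batches = int(batches_in_loader * limit_train_batches)
assert num_training_batches == 0     # so should_skip_training() returns True
                                     # even though current_epoch < max_epochs
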
@@ -203,7 +197,7 @@ def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()
 
         # reset train dataloader
-        if self.trainer.reload_dataloaders_every_epoch:
+        if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         # set seed for distributed sampler (enables shuffling for each epoch)
@@ -238,7 +232,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx,
         self.trainer.logger_connector.on_train_batch_end()
 
     def reset_train_val_dataloaders(self, model):
-        if not self.trainer.reload_dataloaders_every_epoch:
+        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
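
Read together, the two dataloader tweaks appear to serve the new skip condition (this is inferred from the diff rather than stated in the commit message): `reset_train_val_dataloaders` now loads the train dataloader even when `reload_dataloaders_every_epoch` is set, so `num_training_batches` is already resolved by the time `should_skip_training()` runs, and the `epoch != 0` guard in `on_train_epoch_start` keeps that first load from being repeated at the start of epoch 0.
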
68 changes: 62 additions & 6 deletions tests/trainer/test_trainer.py
@@ -11,20 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import Namespace
-from copy import deepcopy
 import math
 import os
-from pathlib import Path
 import pickle
 import sys
+from argparse import Namespace
+from copy import deepcopy
+from pathlib import Path
 from unittest.mock import ANY, call, patch
 
 import cloudpickle
-from omegaconf import OmegaConf
 import pytest
 import torch
+from omegaconf import OmegaConf
 from torch.utils.data import DataLoader
+
+import tests.base.develop_utils as tutils
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
@@ -35,8 +37,7 @@
 from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel, EvalModelTemplate
-import tests.base.develop_utils as tutils
+from tests.base import BoringModel, EvalModelTemplate, RandomDataset
 
 
 @pytest.mark.parametrize("url_ckpt", [True, False])
@@ -1444,3 +1445,58 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
             match=r"Only None, bool, str and subclasses of `BaseProfiler`"
             r" are valid values for `Trainer`'s `profiler` parameter. *"):
         Trainer(profiler=profiler)
+
+
+@pytest.mark.parametrize(
+    ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
+    [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
+)
+def test_disabled_training_for_insufficient_limit_train_batches(tmpdir, limit_train_batches, global_step,
+                                                                num_training_batches, current_epoch, should_train):
+    """
+    Verify when `limit_train_batches` is float & between [0.0, 1.0] and
+    `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
+    """
+    class CurrentModel(BoringModel):
+
+        training_step_invoked = False
+        training_epoch_end_invoked = False
+
+        def training_step(self, *args, **kwargs):
+            self.training_step_invoked = True
+            return super().training_step(*args, **kwargs)
+
+        def training_epoch_end(self, *args, **kwargs):
+            self.training_epoch_end_invoked = True
+            return super().training_epoch_end(*args, **kwargs)
+
+    dataset_len = 100
+    batch_size = 25
+
+    train = RandomDataset(32, length=dataset_len)
+    train_loader = DataLoader(train, batch_size=batch_size)
+
+    model = CurrentModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        limit_train_batches=limit_train_batches,
+    )
+    result = trainer.fit(model, train_loader)
+
+    params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
+                        & `batch_size={batch_size}` as
+                        `num_training_batches={num_training_batches}`"""
+    if should_train:
+        error_string = f"should run with {params_string}"
+    else:
+        error_string = f"should not run with {params_string}"
+
+    assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
+    assert trainer.global_step == global_step
+    assert trainer.num_training_batches == num_training_batches
+    assert trainer.current_epoch == current_epoch
+    assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
+    assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"

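Working through the two parametrized cases: the loader yields 100 / 25 = 4 batches, so `limit_train_batches=0.2` resolves to `int(4 * 0.2) == 0` batches and training never starts (`global_step == 0`, `current_epoch == 0`), while `limit_train_batches=0.5` resolves to 2 batches per epoch over 5 epochs, hence `global_step == 10` and a final `current_epoch == 4`. The new test can be run in isolation with a standard pytest keyword filter, e.g. `python -m pytest tests/trainer/test_trainer.py -k disabled_training_for_insufficient_limit_train_batches`.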