Disable training with zero num_training_batches when insufficient limit_train_batches #5703

Merged
13 commits merged on Feb 3, 2021
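
The motivating case: when `limit_train_batches` is a small fraction, the batch count is truncated to an integer and can end up as zero. A worked sketch of that arithmetic (illustrative numbers only; the formula mirrors the `int(self.num_training_batches * self.limit_train_batches)` expression cited in the new test):

    # Illustration: a fractional limit_train_batches can truncate to zero batches.
    dataset_len = 100
    batch_size = 25
    limit_train_batches = 0.2

    batches_per_epoch = dataset_len // batch_size                         # 4
    num_training_batches = int(batches_per_epoch * limit_train_batches)   # int(0.8) == 0

    # With num_training_batches == 0, the updated should_skip_training() returns True,
    # so trainer.fit() completes without ever invoking training_step().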
12 changes: 3 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -91,13 +91,7 @@ def num_optimizers(self):
         return num_optimizers
 
     def should_skip_training(self):
-        if self.trainer.current_epoch >= self.trainer.max_epochs:
-            return True
-
-        if self.trainer.limit_train_batches == 0:
-            return True
-
-        return False
+        return self.trainer.current_epoch >= self.trainer.max_epochs or self.trainer.num_training_batches == 0
 
     def on_train_start(self):
         # clear cache before training
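
Restated as a standalone helper for readability (a sketch with plain arguments; the actual method reads these values from `self.trainer`):

    def should_skip_training(current_epoch: int, max_epochs: int, num_training_batches: int) -> bool:
        # Skip training once max_epochs is reached, or when the (possibly
        # limit-truncated) number of training batches is zero.
        return current_epoch >= max_epochs or num_training_batches == 0

    assert should_skip_training(current_epoch=0, max_epochs=5, num_training_batches=0) is True
    assert should_skip_training(current_epoch=0, max_epochs=5, num_training_batches=2) is False
    assert should_skip_training(current_epoch=5, max_epochs=5, num_training_batches=2) is True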
@@ -203,7 +197,7 @@ def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()
 
         # reset train dataloader
-        if self.trainer.reload_dataloaders_every_epoch:
+        if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)

         # set seed for distributed sampler (enables shuffling for each epoch)
@@ -238,7 +232,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx,
         self.trainer.logger_connector.on_train_batch_end()
 
     def reset_train_val_dataloaders(self, model):
-        if not self.trainer.reload_dataloaders_every_epoch:
+        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
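
Taken together, the two dataloader guards mean the train dataloader is loaded once at fit start and, with `reload_dataloaders_every_epoch=True`, reloaded only from epoch 1 onward. A minimal sketch of the combined behaviour (hypothetical helper names that mirror the conditions above):

    def load_at_fit_start(train_dataloader, reload_every_epoch: bool) -> bool:
        # reset_train_val_dataloaders: load whenever nothing has been loaded yet.
        return train_dataloader is None or not reload_every_epoch

    def reload_at_epoch_start(epoch: int, reload_every_epoch: bool) -> bool:
        # on_train_epoch_start: epoch 0 reuses the dataloader loaded at fit start.
        return epoch != 0 and reload_every_epoch

    assert load_at_fit_start(None, reload_every_epoch=True) is True
    assert reload_at_epoch_start(0, reload_every_epoch=True) is False   # no redundant reload
    assert reload_at_epoch_start(1, reload_every_epoch=True) is True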
68 changes: 62 additions & 6 deletions tests/trainer/test_trainer.py
@@ -11,20 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import Namespace
-from copy import deepcopy
 import math
 import os
-from pathlib import Path
 import pickle
 import sys
+from argparse import Namespace
+from copy import deepcopy
+from pathlib import Path
 from unittest.mock import ANY, call, patch
 
 import cloudpickle
-from omegaconf import OmegaConf
 import pytest
 import torch
+from omegaconf import OmegaConf
 from torch.utils.data import DataLoader
 
+import tests.base.develop_utils as tutils
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
@@ -35,8 +37,7 @@
 from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel, EvalModelTemplate
-import tests.base.develop_utils as tutils
+from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@pytest.mark.parametrize("url_ckpt", [True, False])
@@ -1444,3 +1445,58 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
            match=r"Only None, bool, str and subclasses of `BaseProfiler`"
            r" are valid values for `Trainer`'s `profiler` parameter. *"):
        Trainer(profiler=profiler)


@pytest.mark.parametrize(
    ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
    [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
)
def test_disabled_training_for_insufficient_limit_train_batches(tmpdir, limit_train_batches, global_step,
                                                                num_training_batches, current_epoch, should_train):
    """
    Verify that when `limit_train_batches` is a float in [0.0, 1.0] and
    `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
    """

    class CurrentModel(BoringModel):

        training_step_invoked = False
        training_epoch_end_invoked = False

        def training_step(self, *args, **kwargs):
            self.training_step_invoked = True
            return super().training_step(*args, **kwargs)

        def training_epoch_end(self, *args, **kwargs):
            self.training_epoch_end_invoked = True
            return super().training_epoch_end(*args, **kwargs)

    dataset_len = 100
    batch_size = 25

    train = RandomDataset(32, length=dataset_len)
    train_loader = DataLoader(train, batch_size=batch_size)

    model = CurrentModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        limit_train_batches=limit_train_batches,
    )
    result = trainer.fit(model, train_loader)

    params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
                        & `batch_size={batch_size}` as
                        `num_training_batches={num_training_batches}`"""
    if should_train:
        error_string = f"should run with {params_string}"
    else:
        error_string = f"should not run with {params_string}"

    assert result == 1, "training failed to complete"
    assert trainer.state == TrainerState.FINISHED
    assert trainer.global_step == global_step
    assert trainer.num_training_batches == num_training_batches
    assert trainer.current_epoch == current_epoch
    assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
    assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"
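
The parametrized expectations follow directly from the dataset and limit arithmetic (a quick check, not part of the test itself):

    batches_per_epoch = 100 // 25                 # dataset_len=100, batch_size=25 -> 4
    assert int(batches_per_epoch * 0.2) == 0      # case 1: training skipped, global_step stays 0
    assert int(batches_per_epoch * 0.5) == 2      # case 2: 2 batches per epoch
    assert 2 * 5 == 10                            # 5 epochs x 2 batches -> global_step == 10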