Skip to content

Commit

Permalink
better raise an error
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Apr 27, 2022
1 parent 1b7833d commit 960e731
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 53 deletions.
2 changes: 1 addition & 1 deletion tests/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def val_dataloader(self):
assert trainer.val_dataloaders[0].sampler.shuffle == shuffle


def test_error_raised_with_float_limit_train_dataloader():
def test_error_raised_with_insufficient_float_limit_train_dataloader():
batch_size = 16
dl = DataLoader(RandomDataset(32, batch_size * 9), batch_size=batch_size)
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.1)
Expand Down
52 changes: 0 additions & 52 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,58 +1416,6 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir):
assert preds[0].dtype == (torch.float64 if precision == 64 else torch.float32)


@pytest.mark.parametrize(
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 5, True)],
)
def test_disabled_training_for_insufficient_limit_train_batches(
tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train
):
"""Verify when `limit_train_batches` is float & between [0.0, 1.0] and.
`int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
"""

class CurrentModel(BoringModel):

training_step_invoked = False
training_epoch_end_invoked = False

def training_step(self, *args, **kwargs):
self.training_step_invoked = True
return super().training_step(*args, **kwargs)

def training_epoch_end(self, *args, **kwargs):
self.training_epoch_end_invoked = True
return super().training_epoch_end(*args, **kwargs)

dataset_len = 100
batch_size = 25

train = RandomDataset(32, length=dataset_len)
train_loader = DataLoader(train, batch_size=batch_size)

model = CurrentModel()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=limit_train_batches)
trainer.fit(model, train_loader)

params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
& `batch_size={batch_size}` as
`num_training_batches={num_training_batches}`"""
if should_train:
error_string = f"should run with {params_string}"
else:
error_string = f"should not run with {params_string}"

assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.global_step == global_step
assert trainer.num_training_batches == num_training_batches
assert trainer.current_epoch == current_epoch
assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"


@pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)])
def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step):
"""Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`,
Expand Down

0 comments on commit 960e731

Please sign in to comment.