Fix val_check_interval with fast_dev_run (#5540)
* fix val_check_interval with fast_dev_run

* chlog
rohitgr7 authored and Borda committed Feb 4, 2021
1 parent cd11fee commit 8f9ceb9
Showing 3 changed files with 45 additions and 19 deletions.
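For context, a minimal repro sketch of the behavior this commit guarantees. This is not part of the commit: `TinyModel` and `RandomDataset` are made-up stand-ins for the `BoringModel` used in the tests, assuming the 1.1.x `Trainer` API.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_ran = False

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.val_ran = True

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)


# Before this fix, a user-supplied val_check_interval could leak through the
# fast_dev_run override; after it, the trainer forces the value back to 1.0
# so the single-batch dev run still runs the validation loop.
model = TinyModel()
trainer = Trainer(fast_dev_run=True, val_check_interval=2)
trainer.fit(model)
assert trainer.val_check_interval == 1.0
assert model.val_ran  # validation ran once despite the custom interval
```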
3 changes: 3 additions & 0 deletions CHANGELOG.md
```diff
@@ -195,6 +195,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))
 
 
+- Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540))
+
+
 ## [1.1.4] - 2021-01-12
 
 ### Added
```
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/debugging_connector.py

```diff
@@ -59,7 +59,7 @@ def on_init_start(
             self.trainer.max_steps = fast_dev_run
             self.trainer.num_sanity_val_steps = 0
             self.trainer.max_epochs = 1
-            self.trainer.val_check_interval = 1.0
+            val_check_interval = 1.0
             self.trainer.check_val_every_n_epoch = 1
             self.trainer.logger = DummyLogger()
```
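The one-line fix works because, later in `on_init_start`, the connector writes the (possibly overridden) local arguments back onto the trainer. Assigning `self.trainer.val_check_interval` directly inside the `fast_dev_run` branch was therefore clobbered by that final write-back of the user-supplied value. A simplified sketch of the pattern, with hypothetical surrounding code rather than the actual connector:

```python
def on_init_start(self, val_check_interval, fast_dev_run):
    if fast_dev_run:
        # Old (buggy): self.trainer.val_check_interval = 1.0, which the
        # write-back below immediately overwrote with the user's value.
        val_check_interval = 1.0  # new: override the local argument instead
    # ... other trainer attributes set here ...
    self.trainer.val_check_interval = val_check_interval  # final write wins
```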
59 changes: 41 additions & 18 deletions tests/trainer/flags/test_fast_dev_run.py
```diff
@@ -37,30 +37,59 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
     class FastDevRunModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.training_step_called = False
-            self.validation_step_called = False
-            self.test_step_called = False
+            self.training_step_call_count = 0
+            self.training_epoch_end_call_count = 0
+            self.validation_step_call_count = 0
+            self.validation_epoch_end_call_count = 0
+            self.test_step_call_count = 0
 
         def training_step(self, batch, batch_idx):
             self.log('some_metric', torch.tensor(7.))
             self.logger.experiment.dummy_log('some_distribution', torch.randn(7) + batch_idx)
-            self.training_step_called = True
+            self.training_step_call_count += 1
             return super().training_step(batch, batch_idx)
 
+        def training_epoch_end(self, outputs):
+            self.training_epoch_end_call_count += 1
+            super().training_epoch_end(outputs)
+
         def validation_step(self, batch, batch_idx):
-            self.validation_step_called = True
+            self.validation_step_call_count += 1
             return super().validation_step(batch, batch_idx)
 
+        def validation_epoch_end(self, outputs):
+            self.validation_epoch_end_call_count += 1
+            super().validation_epoch_end(outputs)
+
+        def test_step(self, batch, batch_idx):
+            self.test_step_call_count += 1
+            return super().test_step(batch, batch_idx)
+
     checkpoint_callback = ModelCheckpoint()
     early_stopping_callback = EarlyStopping()
     trainer_config = dict(
         fast_dev_run=fast_dev_run,
+        val_check_interval=2,
         logger=True,
         log_every_n_steps=1,
         callbacks=[checkpoint_callback, early_stopping_callback],
     )
 
-    def _make_fast_dev_run_assertions(trainer):
+    def _make_fast_dev_run_assertions(trainer, model):
+        # check the call count for train/val/test step/epoch
+        assert model.training_step_call_count == fast_dev_run
+        assert model.training_epoch_end_call_count == 1
+        assert model.validation_step_call_count == (0 if model.validation_step is None else fast_dev_run)
+        assert model.validation_epoch_end_call_count == (0 if model.validation_step is None else 1)
+        assert model.test_step_call_count == fast_dev_run
+
         # check trainer arguments
         assert trainer.max_steps == fast_dev_run
         assert trainer.num_sanity_val_steps == 0
         assert trainer.max_epochs == 1
         assert trainer.val_check_interval == 1.0
         assert trainer.check_val_every_n_epoch == 1
 
         # there should be no logger with fast_dev_run
         assert isinstance(trainer.logger, DummyLogger)
         assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run
```
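The per-call booleans became counters because `fast_dev_run` also accepts an integer n, running n batches of each loop rather than one, so the test can assert exact counts. A small sketch of the trainer-side semantics the assertions above encode, assuming the 1.1.x API:

```python
from pytorch_lightning import Trainer

# fast_dev_run=True behaves like fast_dev_run=1; an int n runs n train/val/test
# batches in a single epoch with the sanity-check loop disabled.
trainer = Trainer(fast_dev_run=7)
assert trainer.max_steps == 7
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0
assert trainer.val_check_interval == 1.0  # the value this commit pins down
```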
```diff
@@ -77,13 +106,10 @@ def _make_fast_dev_run_assertions(trainer):
     train_val_step_model = FastDevRunModel()
     trainer = Trainer(**trainer_config)
     trainer.fit(train_val_step_model)
-    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     trainer.test(ckpt_path=None)
 
-    # make sure both training_step and validation_step were called
-    assert train_val_step_model.training_step_called
-    assert train_val_step_model.validation_step_called
-
-    _make_fast_dev_run_assertions(trainer)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+    _make_fast_dev_run_assertions(trainer, train_val_step_model)
 
     # -----------------------
     # also called once with no val step
@@ -93,10 +119,7 @@ def _make_fast_dev_run_assertions(trainer):
 
     trainer = Trainer(**trainer_config)
     trainer.fit(train_step_only_model)
-    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     trainer.test(ckpt_path=None)
 
-    # make sure only training_step was called
-    assert train_step_only_model.training_step_called
-    assert not train_step_only_model.validation_step_called
-
-    _make_fast_dev_run_assertions(trainer)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+    _make_fast_dev_run_assertions(trainer, train_step_only_model)
```