Skip to content

Commit

Permalink
fix the mess after rebasing
Browse files Browse the repository at this point in the history
  • Loading branch information
hadim committed Feb 26, 2020
1 parent 85f47f4 commit 15a914c
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
self.callbacks: list[Callback] = []
self.get_model: Callable = None
self.get_model: Callable = ...

def on_init_start(self):
"""Called when the trainer initialization begins."""
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def __init__(self):
self.progress_bar_refresh_rate = None

# Callback system
self.on_validation_start: Callable = None
self.on_validation_end: Callable = None
self.on_test_start: Callable = None
self.on_test_end: Callable = None
self.on_validation_start: Callable = ...
self.on_validation_end: Callable = ...
self.on_test_start: Callable = ...
self.on_test_end: Callable = ...

@abstractmethod
def copy_trainer_model_properties(self, model):
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,12 @@ def __init__(self):
# Callback system
self.callbacks: list[Callback] = []
self.max_steps = None
self.on_train_start: Callable = None
self.on_train_end: Callable = None
self.on_batch_start: Callable = None
self.on_batch_end: Callable = None
self.on_epoch_start: Callable = None
self.on_epoch_end: Callable = None
self.on_train_start: Callable = ...
self.on_train_end: Callable = ...
self.on_batch_start: Callable = ...
self.on_batch_end: Callable = ...
self.on_epoch_start: Callable = ...
self.on_epoch_end: Callable = ...

@property
def max_nb_epochs(self):
Expand Down
104 changes: 64 additions & 40 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
LightValStepFitSingleDataloaderMixin,
LightTrainDataloader,
LightTestDataloader,
LightValidationMixin,
LightTestMixin
)
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning import Callback


def test_no_val_module(tmpdir):
Expand Down Expand Up @@ -792,15 +795,15 @@ def test_benchmark_option(tmpdir):
tutils.reset_seed()

class CurrentTestModel(
LightningValidationMultipleDataloadersMixin,
LightningTestModelBase
LightValidationMultipleDataloadersMixin,
LightTrainDataloader,
TestModelBase
):
pass

hparams = tutils.get_hparams()
model = CurrentTestModel(hparams)

<<<<<<< HEAD
# verify torch.backends.cudnn.benchmark is not turned on
assert not torch.backends.cudnn.benchmark

Expand All @@ -820,7 +823,53 @@ class CurrentTestModel(

# verify torch.backends.cudnn.benchmark is not turned off
assert torch.backends.cudnn.benchmark
=======


def test_testpass_overrides(tmpdir):
    """Check that ``Trainer.test()`` enforces the test_step/test_end contract.

    A model that implements neither ``test_step`` nor ``test_end``, or that
    implements ``test_end`` without ``test_step``, must raise
    ``MisconfigurationException``; valid combinations must run without error.
    """
    hparams = tutils.get_hparams()

    # Implements neither test_step nor test_end.
    class LocalModel(LightTrainDataloader, TestModelBase):
        pass

    # NOTE(review): LightEmptyTestStep presumably supplies test_step (no
    # test_end) — confirm against the mixin definition.
    class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
        pass

    # Implements test_end but not test_step.
    class LocalModelNoStep(LightTrainDataloader, TestModelBase):
        def test_end(self, outputs):
            return {}

    # Misconfig when neither test_step nor test_end is implemented
    with pytest.raises(MisconfigurationException):
        model = LocalModel(hparams)
        Trainer().test(model)

    # Misconfig when test_end is implemented without test_step
    with pytest.raises(MisconfigurationException):
        model = LocalModelNoStep(hparams)
        Trainer().test(model)

    # No exceptions when one or both of test_step or test_end are implemented
    model = LocalModelNoEnd(hparams)
    Trainer().test(model)

    model = LightningTestModel(hparams)
    Trainer().test(model)


def test_trainer_callback_system(tmpdir):
"""Test the callback system."""

class CurrentTestModel(
LightTrainDataloader,
LightTestMixin,
LightValidationMixin,
TestModelBase,
):
pass

hparams = tutils.get_hparams()
model = CurrentTestModel(hparams)

class TestCallback(Callback):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -880,46 +929,30 @@ def on_test_start(self, trainer, pl_module):

def on_test_end(self, trainer, pl_module):
self.on_test_end_called = True
>>>>>>> Add trainer and pl_module args to callback methods

test_callback = TestCallback()

def test_testpass_overrides(tmpdir):
hparams = tutils.get_hparams()
trainer_options = {}
trainer_options['callbacks'] = [test_callback]
trainer_options['max_epochs'] = 1
trainer_options['val_percent_check'] = 0.1
trainer_options['train_percent_check'] = 0.2
trainer_options['show_progress_bar'] = False

<<<<<<< HEAD
class LocalModel(LightTrainDataloader, TestModelBase):
pass
=======
assert not test_callback.on_init_start_called
assert not test_callback.on_init_end_called
>>>>>>> Switch to on_.*_start()

class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
pass
# fit model
trainer = Trainer(**trainer_options)

<<<<<<< HEAD
class LocalModelNoStep(LightTrainDataloader, TestModelBase):
def test_end(self, outputs):
return {}
=======
assert trainer.callbacks[0] == test_callback
assert test_callback.on_init_start_called
assert test_callback.on_init_end_called
assert not test_callback.on_fit_start_called
assert not test_callback.on_fit_end_called
>>>>>>> Switch to on_.*_start()

# Misconfig when neither test_step or test_end is implemented
with pytest.raises(MisconfigurationException):
model = LocalModel(hparams)
Trainer().test(model)
trainer.fit(model)

<<<<<<< HEAD
# Misconfig when neither test_step or test_end is implemented
with pytest.raises(MisconfigurationException):
model = LocalModelNoStep(hparams)
Trainer().test(model)
=======
assert test_callback.on_fit_start_called
assert test_callback.on_fit_end_called
assert test_callback.on_epoch_start_called
Expand All @@ -932,20 +965,11 @@ def test_end(self, outputs):
assert test_callback.on_validation_end_called
assert not test_callback.on_test_start_called
assert not test_callback.on_test_end_called
>>>>>>> Switch to on_.*_start()

# No exceptions when one or both of test_step or test_end are implemented
model = LocalModelNoEnd(hparams)
Trainer().test(model)
trainer.test()

<<<<<<< HEAD
model = LightningTestModel(hparams)
Trainer().test(model)
=======
assert test_callback.on_test_start_called
assert test_callback.on_test_end_called
>>>>>>> Switch to on_.*_start()


# if __name__ == '__main__':
# pytest.main([__file__])

0 comments on commit 15a914c

Please sign in to comment.