Skip to content

Commit 15a914c

Browse files
committed
fix the mess after rebasing
1 parent 85f47f4 commit 15a914c

File tree

4 files changed

+75
-51
lines changed

4 files changed

+75
-51
lines changed

pytorch_lightning/trainer/callback_hook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self):
1010
# this is just a summary on variables used in this abstract class,
1111
# the proper values/initialisation should be done in child class
1212
self.callbacks: list[Callback] = []
13-
self.get_model: Callable = None
13+
self.get_model: Callable = ...
1414

1515
def on_init_start(self):
1616
"""Called when the trainer initialization begins."""

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def __init__(self):
174174
self.progress_bar_refresh_rate = None
175175

176176
# Callback system
177-
self.on_validation_start: Callable = None
178-
self.on_validation_end: Callable = None
179-
self.on_test_start: Callable = None
180-
self.on_test_end: Callable = None
177+
self.on_validation_start: Callable = ...
178+
self.on_validation_end: Callable = ...
179+
self.on_test_start: Callable = ...
180+
self.on_test_end: Callable = ...
181181

182182
@abstractmethod
183183
def copy_trainer_model_properties(self, model):

pytorch_lightning/trainer/training_loop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,12 @@ def __init__(self):
235235
# Callback system
236236
self.callbacks: list[Callback] = []
237237
self.max_steps = None
238-
self.on_train_start: Callable = None
239-
self.on_train_end: Callable = None
240-
self.on_batch_start: Callable = None
241-
self.on_batch_end: Callable = None
242-
self.on_epoch_start: Callable = None
243-
self.on_epoch_end: Callable = None
238+
self.on_train_start: Callable = ...
239+
self.on_train_end: Callable = ...
240+
self.on_batch_start: Callable = ...
241+
self.on_batch_end: Callable = ...
242+
self.on_epoch_start: Callable = ...
243+
self.on_epoch_end: Callable = ...
244244

245245
@property
246246
def max_nb_epochs(self):

tests/test_trainer.py

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
LightValStepFitSingleDataloaderMixin,
2424
LightTrainDataloader,
2525
LightTestDataloader,
26+
LightValidationMixin,
27+
LightTestMixin
2628
)
2729
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
2830
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
2931
from pytorch_lightning.utilities.debugging import MisconfigurationException
32+
from pytorch_lightning import Callback
3033

3134

3235
def test_no_val_module(tmpdir):
@@ -792,15 +795,15 @@ def test_benchmark_option(tmpdir):
792795
tutils.reset_seed()
793796

794797
class CurrentTestModel(
795-
LightningValidationMultipleDataloadersMixin,
796-
LightningTestModelBase
798+
LightValidationMultipleDataloadersMixin,
799+
LightTrainDataloader,
800+
TestModelBase
797801
):
798802
pass
799803

800804
hparams = tutils.get_hparams()
801805
model = CurrentTestModel(hparams)
802806

803-
<<<<<<< HEAD
804807
# verify torch.backends.cudnn.benchmark is not turned on
805808
assert not torch.backends.cudnn.benchmark
806809

@@ -820,7 +823,53 @@ class CurrentTestModel(
820823

821824
# verify torch.backends.cudnn.benchmark is not turned off
822825
assert torch.backends.cudnn.benchmark
823-
=======
826+
827+
828+
def test_testpass_overrides(tmpdir):
829+
hparams = tutils.get_hparams()
830+
831+
class LocalModel(LightTrainDataloader, TestModelBase):
832+
pass
833+
834+
class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
835+
pass
836+
837+
class LocalModelNoStep(LightTrainDataloader, TestModelBase):
838+
def test_end(self, outputs):
839+
return {}
840+
841+
# Misconfig when neither test_step or test_end is implemented
842+
with pytest.raises(MisconfigurationException):
843+
model = LocalModel(hparams)
844+
Trainer().test(model)
845+
846+
# Misconfig when neither test_step or test_end is implemented
847+
with pytest.raises(MisconfigurationException):
848+
model = LocalModelNoStep(hparams)
849+
Trainer().test(model)
850+
851+
# No exceptions when one or both of test_step or test_end are implemented
852+
model = LocalModelNoEnd(hparams)
853+
Trainer().test(model)
854+
855+
model = LightningTestModel(hparams)
856+
Trainer().test(model)
857+
858+
859+
def test_trainer_callback_system(tmpdir):
860+
"""Test the callback system."""
861+
862+
class CurrentTestModel(
863+
LightTrainDataloader,
864+
LightTestMixin,
865+
LightValidationMixin,
866+
TestModelBase,
867+
):
868+
pass
869+
870+
hparams = tutils.get_hparams()
871+
model = CurrentTestModel(hparams)
872+
824873
class TestCallback(Callback):
825874
def __init__(self):
826875
super().__init__()
@@ -880,46 +929,30 @@ def on_test_start(self, trainer, pl_module):
880929

881930
def on_test_end(self, trainer, pl_module):
882931
self.on_test_end_called = True
883-
>>>>>>> Add trainer and pl_module args to callback methods
884932

933+
test_callback = TestCallback()
885934

886-
def test_testpass_overrides(tmpdir):
887-
hparams = tutils.get_hparams()
935+
trainer_options = {}
936+
trainer_options['callbacks'] = [test_callback]
937+
trainer_options['max_epochs'] = 1
938+
trainer_options['val_percent_check'] = 0.1
939+
trainer_options['train_percent_check'] = 0.2
940+
trainer_options['show_progress_bar'] = False
888941

889-
<<<<<<< HEAD
890-
class LocalModel(LightTrainDataloader, TestModelBase):
891-
pass
892-
=======
893942
assert not test_callback.on_init_start_called
894943
assert not test_callback.on_init_end_called
895-
>>>>>>> Switch to on_.*_start()
896944

897-
class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
898-
pass
945+
# fit model
946+
trainer = Trainer(**trainer_options)
899947

900-
<<<<<<< HEAD
901-
class LocalModelNoStep(LightTrainDataloader, TestModelBase):
902-
def test_end(self, outputs):
903-
return {}
904-
=======
905948
assert trainer.callbacks[0] == test_callback
906949
assert test_callback.on_init_start_called
907950
assert test_callback.on_init_end_called
908951
assert not test_callback.on_fit_start_called
909952
assert not test_callback.on_fit_start_called
910-
>>>>>>> Switch to on_.*_start()
911953

912-
# Misconfig when neither test_step or test_end is implemented
913-
with pytest.raises(MisconfigurationException):
914-
model = LocalModel(hparams)
915-
Trainer().test(model)
954+
trainer.fit(model)
916955

917-
<<<<<<< HEAD
918-
# Misconfig when neither test_step or test_end is implemented
919-
with pytest.raises(MisconfigurationException):
920-
model = LocalModelNoStep(hparams)
921-
Trainer().test(model)
922-
=======
923956
assert test_callback.on_fit_start_called
924957
assert test_callback.on_fit_end_called
925958
assert test_callback.on_epoch_start_called
@@ -932,20 +965,11 @@ def test_end(self, outputs):
932965
assert test_callback.on_validation_end_called
933966
assert not test_callback.on_test_start_called
934967
assert not test_callback.on_test_end_called
935-
>>>>>>> Switch to on_.*_start()
936968

937-
# No exceptions when one or both of test_step or test_end are implemented
938-
model = LocalModelNoEnd(hparams)
939-
Trainer().test(model)
969+
trainer.test()
940970

941-
<<<<<<< HEAD
942-
model = LightningTestModel(hparams)
943-
Trainer().test(model)
944-
=======
945971
assert test_callback.on_test_start_called
946972
assert test_callback.on_test_end_called
947-
>>>>>>> Switch to on_.*_start()
948-
949973

950974
# if __name__ == '__main__':
951975
# pytest.main([__file__])

0 commit comments

Comments
 (0)