2323 LightValStepFitSingleDataloaderMixin ,
2424 LightTrainDataloader ,
2525 LightTestDataloader ,
26+ LightValidationMixin ,
27+ LightTestMixin
2628)
2729from pytorch_lightning .core .lightning import load_hparams_from_tags_csv
2830from pytorch_lightning .trainer .logging import TrainerLoggingMixin
2931from pytorch_lightning .utilities .debugging import MisconfigurationException
32+ from pytorch_lightning import Callback
3033
3134
3235def test_no_val_module (tmpdir ):
@@ -792,15 +795,15 @@ def test_benchmark_option(tmpdir):
792795 tutils .reset_seed ()
793796
794797 class CurrentTestModel (
795- LightningValidationMultipleDataloadersMixin ,
796- LightningTestModelBase
798+ LightValidationMultipleDataloadersMixin ,
799+ LightTrainDataloader ,
800+ TestModelBase
797801 ):
798802 pass
799803
800804 hparams = tutils .get_hparams ()
801805 model = CurrentTestModel (hparams )
802806
803- < << << << HEAD
804807 # verify torch.backends.cudnn.benchmark is not turned on
805808 assert not torch .backends .cudnn .benchmark
806809
@@ -820,7 +823,53 @@ class CurrentTestModel(
820823
821824 # verify torch.backends.cudnn.benchmark is not turned off
822825 assert torch .backends .cudnn .benchmark
823- == == == =
826+
827+
828+ def test_testpass_overrides (tmpdir ):
829+ hparams = tutils .get_hparams ()
830+
831+ class LocalModel (LightTrainDataloader , TestModelBase ):
832+ pass
833+
834+ class LocalModelNoEnd (LightTrainDataloader , LightTestDataloader , LightEmptyTestStep , TestModelBase ):
835+ pass
836+
837+ class LocalModelNoStep (LightTrainDataloader , TestModelBase ):
838+ def test_end (self , outputs ):
839+ return {}
840+
841+ # Misconfig when neither test_step or test_end is implemented
842+ with pytest .raises (MisconfigurationException ):
843+ model = LocalModel (hparams )
844+ Trainer ().test (model )
845+
846+ # Misconfig when neither test_step or test_end is implemented
847+ with pytest .raises (MisconfigurationException ):
848+ model = LocalModelNoStep (hparams )
849+ Trainer ().test (model )
850+
851+ # No exceptions when one or both of test_step or test_end are implemented
852+ model = LocalModelNoEnd (hparams )
853+ Trainer ().test (model )
854+
855+ model = LightningTestModel (hparams )
856+ Trainer ().test (model )
857+
858+
859+ def test_trainer_callback_system (tmpdir ):
860+ """Test the callback system."""
861+
862+ class CurrentTestModel (
863+ LightTrainDataloader ,
864+ LightTestMixin ,
865+ LightValidationMixin ,
866+ TestModelBase ,
867+ ):
868+ pass
869+
870+ hparams = tutils .get_hparams ()
871+ model = CurrentTestModel (hparams )
872+
824873 class TestCallback (Callback ):
825874 def __init__ (self ):
826875 super ().__init__ ()
@@ -880,46 +929,30 @@ def on_test_start(self, trainer, pl_module):
880929
881930 def on_test_end (self , trainer , pl_module ):
882931 self .on_test_end_called = True
883- >> >> >> > Add trainer and pl_module args to callback methods
884932
933+ test_callback = TestCallback ()
885934
886- def test_testpass_overrides (tmpdir ):
887- hparams = tutils .get_hparams ()
935+ trainer_options = {}
936+ trainer_options ['callbacks' ] = [test_callback ]
937+ trainer_options ['max_epochs' ] = 1
938+ trainer_options ['val_percent_check' ] = 0.1
939+ trainer_options ['train_percent_check' ] = 0.2
940+ trainer_options ['show_progress_bar' ] = False
888941
889- < << << << HEAD
890- class LocalModel (LightTrainDataloader , TestModelBase ):
891- pass
892- == == == =
893942 assert not test_callback .on_init_start_called
894943 assert not test_callback .on_init_end_called
895- >> >> >> > Switch to on_ .* _start ()
896944
897- class LocalModelNoEnd ( LightTrainDataloader , LightTestDataloader , LightEmptyTestStep , TestModelBase ):
898- pass
945+ # fit model
946+ trainer = Trainer ( ** trainer_options )
899947
900- < << << << HEAD
901- class LocalModelNoStep (LightTrainDataloader , TestModelBase ):
902- def test_end (self , outputs ):
903- return {}
904- == == == =
905948 assert trainer .callbacks [0 ] == test_callback
906949 assert test_callback .on_init_start_called
907950 assert test_callback .on_init_end_called
908951 assert not test_callback .on_fit_start_called
909952 assert not test_callback .on_fit_start_called
910- >> >> >> > Switch to on_ .* _start ()
911953
912- # Misconfig when neither test_step or test_end is implemented
913- with pytest .raises (MisconfigurationException ):
914- model = LocalModel (hparams )
915- Trainer ().test (model )
954+ trainer .fit (model )
916955
917- < << << << HEAD
918- # Misconfig when neither test_step or test_end is implemented
919- with pytest .raises (MisconfigurationException ):
920- model = LocalModelNoStep (hparams )
921- Trainer ().test (model )
922- == == == =
923956 assert test_callback .on_fit_start_called
924957 assert test_callback .on_fit_end_called
925958 assert test_callback .on_epoch_start_called
@@ -932,20 +965,11 @@ def test_end(self, outputs):
932965 assert test_callback .on_validation_end_called
933966 assert not test_callback .on_test_start_called
934967 assert not test_callback .on_test_end_called
935- >> >> >> > Switch to on_ .* _start ()
936968
937- # No exceptions when one or both of test_step or test_end are implemented
938- model = LocalModelNoEnd (hparams )
939- Trainer ().test (model )
969+ trainer .test ()
940970
941- < << << << HEAD
942- model = LightningTestModel (hparams )
943- Trainer ().test (model )
944- == == == =
945971 assert test_callback .on_test_start_called
946972 assert test_callback .on_test_end_called
947- >> >> >> > Switch to on_ .* _start ()
948-
949973
950974# if __name__ == '__main__':
951975# pytest.main([__file__])
0 commit comments