@@ -24,6 +24,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.skipif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
@@ -114,40 +115,37 @@ def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_proc
     assert trainer.lightning_module == model
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, min_torch="1.6.0")
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
 def test_swa_callback_ddp(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp", gpus=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, min_torch="1.6.0")
 def test_swa_callback_ddp_spawn(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp_spawn", gpus=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
 @pytest.mark.skipif(platform.system() == "Windows", reason="ddp_cpu is not available on Windows")
 def test_swa_callback_ddp_cpu(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp_cpu", num_processes=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU machine")
+@RunIf(min_gpus=1, min_torch="1.6.0")
 def test_swa_callback_1_gpu(tmpdir):
     train_with_swa(tmpdir, gpus=1)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
 def test_swa_callback(tmpdir, batchnorm):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
         StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
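
As the `match` string above indicates, `swa_epoch_start` must be a positive integer or a float strictly between 0 and 1. For illustration, a valid configuration could look like the following minimal sketch (the argument values are assumptions for the example, not taken from the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Begin weight averaging after 80% of training, with a constant SWA learning rate.
# Values here are illustrative only.
swa = StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=0.05)
trainer = Trainer(max_epochs=10, callbacks=[swa])
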
@@ -161,7 +159,7 @@ def test_swa_raises():
 
 @pytest.mark.parametrize('stochastic_weight_avg', [False, True])
 @pytest.mark.parametrize('use_callbacks', [False, True])
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
 def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
     """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
 
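
For context, the `RunIf` decorator imported from `tests.helpers.skipif` replaces the stacked `@pytest.mark.skipif` markers with a single marker that bundles the GPU-count and torch-version requirements. A minimal sketch of how such a helper can be built on top of `pytest.mark.skipif` follows; the repository's actual implementation may differ, and everything here besides the `RunIf` name and its `min_gpus`/`min_torch` parameters is an assumption:

import pytest
import torch
from distutils.version import LooseVersion


class RunIf:
    """Sketch: aggregates test requirements into a single pytest skip marker.

    Example: @RunIf(min_gpus=2, min_torch="1.6.0")
    """

    def __new__(cls, min_gpus: int = 0, min_torch: str = None):
        conditions = []  # skip the test when any of these is True
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"requires at least {min_gpus} GPUs")

        if min_torch:
            conditions.append(LooseVersion(torch.__version__) < LooseVersion(min_torch))
            reasons.append(f"requires torch >= {min_torch}")

        return pytest.mark.skipif(any(conditions), reason=", ".join(reasons) or "no requirement")

Because `__new__` returns the `skipif` marker rather than a `RunIf` instance, the class can be applied directly as a test decorator, which is what allows the one-line replacements seen throughout the diff.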