@@ -163,7 +163,7 @@ def parity_automatic_train_with_one_optimizer(ctx):
         accumulate_grad_batches=accumulate_grad_batches if not ctx["vanilla"] else 1,
         amp_backend=ctx["amp_backend"],
         precision=ctx["precision"],
-        gpus=1
+        gpus=1 if ctx["device"] == 'cuda' else 0,
     )
     trainer.fit(model)

@@ -174,35 +174,37 @@ def parity_automatic_train_with_one_optimizer(ctx):
 ################## TESTS ################## # noqa E266


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-@pytest.mark.parametrize(["precision", "amp_backend"], [
-    pytest.param(16, "native"),
-    pytest.param(32, "native"),
+@pytest.mark.parametrize(["precision", "amp_backend", "device"], [
+    pytest.param(16, "native", "cuda",
+                 marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")),
+    pytest.param(32, "native", "cpu"),
 ])
-@pytest.mark.parametrize('accumulate_grad_batches', [1, 2])
-def test_parity_automatic_training_with_one_optimizer(tmpdir, amp_backend, precision, accumulate_grad_batches):
+@pytest.mark.parametrize('accumulate_grad_batches', [1])
+def test_parity_automatic_training_with_one_optimizer(tmpdir, amp_backend, precision, device, accumulate_grad_batches):
184184 """
185185 Test training with accumulated gradients with and within enable_pl_optimizer reaches the same weights
186186 """
     # prepare arguments
     if accumulate_grad_batches > 1:
         accumulate_grad_batches = np.random.randint(2, accumulate_grad_batches + 1)

-    ctx = {}
-    ctx["tmpdir"] = tmpdir
-    ctx["accumulate_grad_batches"] = accumulate_grad_batches
-    ctx["amp_backend"] = amp_backend
-    ctx["precision"] = precision
-    ctx["using_amp"] = (amp_backend in ["native"]) and precision == 16
-    ctx["max_epochs"] = np.random.randint(1, 3)
-    ctx["limit_train_batches"] = np.random.randint(11, 27)
+    ctx = dict(
+        tmpdir=tmpdir,
+        accumulate_grad_batches=accumulate_grad_batches,
+        amp_backend=amp_backend,
+        precision=precision,
+        using_amp=(amp_backend in ["native"]) and precision == 16,
+        max_epochs=np.random.randint(1, 3),
+        limit_train_batches=np.random.randint(11, 27),
+        limit_val_batches=0,
+        initial_weights={},
+        enable_pl_optimizer=True,
+        mocked=False,
+        vanilla=False,
+        device=device,
+    )
     expected_batches = ctx["max_epochs"] * ctx["limit_train_batches"]
     ctx["expected_batches"] = expected_batches
-    ctx["limit_val_batches"] = 0
-    ctx["initial_weights"] = {}
-    ctx["enable_pl_optimizer"] = True
-    ctx["mocked"] = False
-    ctx["vanilla"] = False

     model_wi_pl_optimizer = parity_automatic_train_with_one_optimizer(ctx)

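For reference, a minimal standalone sketch of the parametrization pattern introduced above: attaching `marks=pytest.mark.skipif(...)` to an individual `pytest.param` skips only the CUDA combination on machines without a GPU, so the CPU combination still runs everywhere. The test name and trivial body below are illustrative only and are not part of this change.

```python
import pytest
import torch


@pytest.mark.parametrize(["precision", "amp_backend", "device"], [
    # Only the CUDA case carries the skipif mark, so CPU-only CI still runs the other case.
    pytest.param(16, "native", "cuda",
                 marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")),
    pytest.param(32, "native", "cpu"),
])
def test_runs_on_requested_device(precision, amp_backend, device):
    # Illustrative body: confirm the requested device string is usable with torch.
    x = torch.zeros(1, device=device)
    assert x.device.type == device
```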