@@ -267,14 +267,17 @@ def configure_optimizers(self):
 @RunIf(min_cuda_gpus=2, fairscale=True)
 @pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
 def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
-    """Test to ensure that checkpoint is saved correctly when using faircale optimizer."""
+    """Test to ensure that checkpoint is saved correctly when using fairscale optimizer."""
     model = BoringFairScaleOptimizerModel()
     trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)

     trainer.fit(model)

     checkpoint_path = os.path.join(tmpdir, "model.pt")
+    # need to broadcast because tmpdir is different on each process
+    checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
+    trainer.strategy.barrier()  # ensure the checkpoint is saved before load
     saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

     # Assert model parameters are identical after loading
@@ -297,7 +300,10 @@ def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
     trainer.fit(model)

     checkpoint_path = os.path.join(tmpdir, "model.pt")
+    # need to broadcast because tmpdir is different on each process
+    checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
+    trainer.strategy.barrier()  # ensure the checkpoint is saved before load
     saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

     # Assert model parameters are identical after loading
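
The added comments describe the same pattern in both tests: broadcast rank 0's checkpoint path so every process agrees on one file location, then barrier before loading so no rank reads a checkpoint that is still being written. A minimal standalone sketch of that pattern, assuming PyTorch Lightning's `Strategy.broadcast`/`Strategy.barrier` API; the helper name and the `BoringModel` import path are illustrative, not part of this PR:

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


def save_and_reload_checkpoint(trainer: Trainer, tmpdir) -> BoringModel:
    """Save a checkpoint in a multi-process run and reload it on every rank."""
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    # tmpdir is created independently on each process; broadcast rank 0's path
    # so all ranks write to and read from the same file.
    checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
    trainer.save_checkpoint(checkpoint_path)
    # wait until the checkpoint has been written before any rank loads it
    trainer.strategy.barrier()
    return BoringModel.load_from_checkpoint(checkpoint_path)
```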