@@ -483,11 +483,6 @@ def fit(args, model, data_loader):
483
483
# select gpu for horovod process
484
484
if 'horovod' in args .kv_store :
485
485
args .gpus = [args .gpus [hvd .local_rank ()]]
486
- ctx = mx .gpu (hvd .local_rank ())
487
-
488
- tensor1 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
489
- tensor2 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
490
- tensor1 , tensor2 = hvd .grouped_allreduce ([tensor1 ,tensor2 ])
491
486
492
487
if args .amp :
493
488
amp .init ()
@@ -579,6 +574,11 @@ def fit(args, model, data_loader):
579
574
params = model .collect_params ()
580
575
if params is not None :
581
576
hvd .broadcast_parameters (params , root_rank = 0 )
577
+ ctx = mx .gpu (hvd .local_rank ())
578
+ tensor1 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
579
+ tensor2 = mx .nd .zeros (shape = (1 ,), dtype = 'float32' , ctx = ctx )
580
+ tensor1 , tensor2 = hvd .grouped_allreduce ([tensor1 ,tensor2 ])
581
+
582
582
global_metrics = CompositeMeter ()
583
583
if args .mode in ['train_val' , 'train' ]:
584
584
global_metrics .register_metric ('train.loss' , MinMeter ())
0 commit comments