Skip to content

Commit 810bcf3

Browse files
mmarcinkiewicznv-kkudrynski
authored andcommitted
[resnet/mxnet] Apply horovod patch for hvd init
1 parent 9becdf8 commit 810bcf3

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

MxNet/Classification/RN50v1.5/dali.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def add_dali_args(parser):
3131
group.add_argument('--dali-validation-threads', type=int, default=10, help="number of threads" +\
3232
"per GPU for DALI for validation")
3333
group.add_argument('--dali-prefetch-queue', type=int, default=5, help="DALI prefetch queue depth")
34-
group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=256, help="Memory padding value for nvJPEG (in MB)")
34+
group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=64, help="Memory padding value for nvJPEG (in MB)")
3535
group.add_argument('--dali-fuse-decoder', type=int, default=1, help="0 or 1 whether to fuse decoder or not")
3636

3737
group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")

MxNet/Classification/RN50v1.5/fit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,6 @@ def fit(args, model, data_loader):
483483
# select gpu for horovod process
484484
if 'horovod' in args.kv_store:
485485
args.gpus = [args.gpus[hvd.local_rank()]]
486-
ctx = mx.gpu(hvd.local_rank())
487-
488-
tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
489-
tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
490-
tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
491486

492487
if args.amp:
493488
amp.init()
@@ -579,6 +574,11 @@ def fit(args, model, data_loader):
579574
params = model.collect_params()
580575
if params is not None:
581576
hvd.broadcast_parameters(params, root_rank=0)
577+
ctx = mx.gpu(hvd.local_rank())
578+
tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
579+
tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
580+
tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
581+
582582
global_metrics = CompositeMeter()
583583
if args.mode in ['train_val', 'train']:
584584
global_metrics.register_metric('train.loss', MinMeter())

0 commit comments

Comments
 (0)