diff --git a/cityscapes_train.py b/cityscapes_train.py index 19a4a22..702b2a7 100644 --- a/cityscapes_train.py +++ b/cityscapes_train.py @@ -175,6 +175,7 @@ def train_model(args): model = torch.nn.DataParallel(model).cuda() #multi-card data parallel else: print("single GPU for training") + args.gpu_nums = torch.cuda.device_count() model = model.cuda() #1-card data parallel args.savedir = ( args.savedir + args.dataset + '/'+ network_type +"_M"+ str(M) + 'N' +str(N) + 'bs'