diff --git a/imagenet/main.py b/imagenet/main.py index 8d41291745..d5731ac1da 100644 --- a/imagenet/main.py +++ b/imagenet/main.py @@ -292,7 +292,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) - if torch.cuda.is_available(): + if args.gpu is not None: target = target.cuda(args.gpu, non_blocking=True) # compute output @@ -336,7 +336,7 @@ def validate(val_loader, model, criterion, args): for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) - if torch.cuda.is_available(): + if args.gpu is not None: target = target.cuda(args.gpu, non_blocking=True) # compute output