Skip to content

Commit 57997c4

Browse files
author
Quentin Anthony
committed
Normalize train and validate GPU movement conditions
1 parent 78186fd commit 57997c4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

imagenet/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
290290
# measure data loading time
291291
data_time.update(time.time() - end)
292292

293-
if torch.cuda.is_available():
293+
if args.gpu is not None:
294294
images = images.cuda(args.gpu, non_blocking=True)
295-
if torch.cuda.is_available():
295+
if args.gpu is not None:
296296
target = target.cuda(args.gpu, non_blocking=True)
297297

298298
# compute output
@@ -334,9 +334,9 @@ def validate(val_loader, model, criterion, args):
334334
with torch.no_grad():
335335
end = time.time()
336336
for i, (images, target) in enumerate(val_loader):
337-
if torch.cuda.is_available():
337+
if args.gpu is not None:
338338
images = images.cuda(args.gpu, non_blocking=True)
339-
if torch.cuda.is_available():
339+
if args.gpu is not None:
340340
target = target.cuda(args.gpu, non_blocking=True)
341341

342342
# compute output

0 commit comments

Comments
 (0)