Skip to content

Commit 6314a9c

Browse files
Update main.py
Normalize train and validate if statements
1 parent 78186fd commit 6314a9c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

imagenet/main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,11 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
290290
# measure data loading time
291291
data_time.update(time.time() - end)
292292

293-
if torch.cuda.is_available():
293+
if args.gpu is not None:
294294
images = images.cuda(args.gpu, non_blocking=True)
295-
if torch.cuda.is_available():
295+
if args.gpu is not None:
296296
target = target.cuda(args.gpu, non_blocking=True)
297-
297+
298298
# compute output
299299
output = model(images)
300300
loss = criterion(output, target)
@@ -334,9 +334,9 @@ def validate(val_loader, model, criterion, args):
334334
with torch.no_grad():
335335
end = time.time()
336336
for i, (images, target) in enumerate(val_loader):
337-
if torch.cuda.is_available():
337+
if args.gpu is not None:
338338
images = images.cuda(args.gpu, non_blocking=True)
339-
if torch.cuda.is_available():
339+
if args.gpu is not None:
340340
target = target.cuda(args.gpu, non_blocking=True)
341341

342342
# compute output

0 commit comments

Comments
 (0)