diff --git a/imagenet/main.py b/imagenet/main.py
index cc32d50733..8b2cf2f038 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -137,6 +137,7 @@ def main_worker(gpu, ngpus_per_node, args):
             # For multiprocessing distributed training, rank needs to be the
             # global rank among all the processes
             args.rank = args.rank * ngpus_per_node + gpu
+            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                 world_size=args.world_size, rank=args.rank)
     # create model
@@ -154,20 +155,14 @@ def main_worker(gpu, ngpus_per_node, args):
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
         if torch.cuda.is_available():
+            model.cuda()
             if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
                 args.batch_size = int(args.batch_size / ngpus_per_node)
                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
+            model = torch.nn.parallel.DistributedDataParallel(model)
     elif args.gpu is not None and torch.cuda.is_available():
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
@@ -183,7 +178,7 @@ def main_worker(gpu, ngpus_per_node, args):
         model = torch.nn.DataParallel(model).cuda()
 
     if torch.cuda.is_available():
-        if args.gpu:
+        if args.gpu and not args.distributed:
             device = torch.device('cuda:{}'.format(args.gpu))
         else:
             device = torch.device("cuda")
@@ -205,17 +200,11 @@ def main_worker(gpu, ngpus_per_node, args):
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
-            if args.gpu is None:
-                checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
-                # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
-                checkpoint = torch.load(args.resume, map_location=loc)
+            checkpoint = torch.load(args.resume, map_location=device)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
-            if args.gpu is not None:
-                # best_acc1 may be from a checkpoint from a different GPU
-                best_acc1 = best_acc1.to(args.gpu)
+            # best_acc1 may be from a checkpoint from a different GPU
+            best_acc1 = best_acc1.to(device=device)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
             scheduler.load_state_dict(checkpoint['scheduler'])
@@ -270,7 +259,7 @@ def main_worker(gpu, ngpus_per_node, args):
         num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     if args.evaluate:
-        validate(val_loader, model, criterion, args)
+        validate(val_loader, model, criterion, device, args)
         return
 
     for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
         train(train_loader, model, criterion, optimizer, epoch, device, args)
 
         # evaluate on validation set
-        acc1 = validate(val_loader, model, criterion, args)
+        acc1 = validate(val_loader, model, criterion, device, args)
 
         scheduler.step()
 
@@ -302,11 +291,11 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+    batch_time = AverageMeter('Time', device, ':6.3f')
+    data_time = AverageMeter('Data', device, ':6.3f')
+    losses = AverageMeter('Loss', device, ':.4e')
+    top1 = AverageMeter('Acc@1', device, ':6.2f')
+    top5 = AverageMeter('Acc@5', device, ':6.2f')
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -347,20 +336,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
             progress.display(i + 1)
 
 
-def validate(val_loader, model, criterion, args):
+def validate(val_loader, model, criterion, device, args):
 
     def run_validate(loader, base_progress=0):
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                images = images.to(device, non_blocking=True)
+                target = target.to(device, non_blocking=True)
 
                 # compute output
                 output = model(images)
@@ -379,10 +363,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', device, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', device, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', device, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', device, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
@@ -422,8 +406,9 @@ class Summary(Enum):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, device, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.device = device
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -441,13 +426,7 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
     def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=self.device)
         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
         self.sum, self.count = total.tolist()
         self.avg = self.sum / self.count