import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import os
import argparse
import random
import numpy as np
import time

from sklearn.cluster import KMeans

from loader import Loader2, Loader
from utils import setup_logger
from custom_datasets import trans_dict, cls_dict, get_num_images
from models import models_dict

random.seed(0)


def get_args():
    parser = argparse.ArgumentParser(description='Active Learning Process')
    # Dataset-related arguments
    parser.add_argument('--dataset', '-d', default='cifar10', type=str, help='Name of the dataset.')
    parser.add_argument('--datapath', default='DATAPATH', type=str, help='Path to the dataset.')
    # Model-related arguments
    parser.add_argument('--net', '-n', default='vgg16', type=str, help='Name of the model.')
    # AL-related arguments
    parser.add_argument('--sorted_dataset_path', default='', type=str,
                        help='Path to the sorted unlabeled pool (.npy or .txt) generated by kmeans.py.')
    parser.add_argument('--sort', default='high2low', choices=['low2high', 'high2low', 'uniform'], type=str,
                        help='Sorting order.')
    parser.add_argument('--first', default='high1st', choices=['low1st', 'high1st', 'uni', 'rand'], type=str,
                        help='Selection rule for the first cycle.')
    parser.add_argument('--sampling', '-s', type=str, required=True,
                        choices=['confidence', 'entropy', 'class_balanced', 'kmeans', 'rand'],
                        help='Type of sampling.')
    parser.add_argument('--cycles', default=None, type=int, help='Number of cycles.')
    parser.add_argument('--start_cycle', default=0, type=int, help='Starting cycle.')
    parser.add_argument('--beta', default=1.0, type=float, help='Balance factor.')
    parser.add_argument('--resume', '-r', default=None, type=str, help='Checkpoint path for resuming training.')
    parser.add_argument('--save', default=None, type=str, help='Save path for trained models.')
    parser.add_argument('--addendum', default=5000, type=int, help='Length of the unlabeled pool for labeling.')
    # Training-related arguments
    parser.add_argument('--milestones', nargs='+', type=int, default=[30, 60, 90], help='List of epoch milestones.')
    parser.add_argument('--epochs', default=100, type=int, help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size for training.')
    parser.add_argument('--lr', default=0.1, type=float, help='Learning rate for training.')
    parser.add_argument('--per_samples_list', nargs='+', type=int,
                        default=[10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
                        help='Percentage of samples labeled in each cycle.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for optimizer.')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for optimizer.')
    args = parser.parse_args()
    print(f'saving in {args.save}')
    return args
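
# Example invocation (a sketch only: the script filename and every path below are
# placeholders, not taken from this repository; adjust them to your setup):
#   python main.py -d cifar10 --datapath /path/to/data -n vgg16 \
#       --sorted_dataset_path sorted/unlabeled_pool.npy -s confidence \
#       --save ./output/cifar10_confidence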


def train(models, criterion, optimizers, epoch, trainloader):
    models['backbone'].train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizers['backbone'].zero_grad()
        outputs = models['backbone'](inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizers['backbone'].step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    logger.info('Train: [{}] Loss: {:.3f} | Acc: {:.3f}'.format(epoch, train_loss/(batch_idx+1), 100.*correct/total))


def test(models, criterion, epoch, cycle):
    global best_acc
    models['backbone'].eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, feat = models['backbone'](inputs, is_feat=True)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    logger.info('Test:  [{}] Loss: {:.3f} | Acc: {:.3f}'.format(epoch, test_loss/(batch_idx+1), 100.*correct/total))

    # Save a checkpoint whenever the test accuracy improves within the current cycle.
    acc = 100.*correct/total
    if acc > best_acc:
        state = {
            'net': models['backbone'].state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, os.path.join(args.save, 'main_{}.pth'.format(cycle)))
        logger.info('Saved main_{}.pth with acc={}!'.format(cycle, acc))
        best_acc = acc


def get_class_balanced(models, samples, cycle):
    '''Class-balanced sampling (pseudo labeling): keep an equal number of samples per
    predicted class, topping up under-represented classes from the overflow.'''
    net = models['backbone']
    per_class = num_samples_list[cycle] // args.num_classes
    # Dictionary with args.num_classes keys as class labels.
    class_dict = {c: [] for c in range(args.num_classes)}
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,
                    transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    # Overflow goes into remaining.
    remaining = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            if len(class_dict[predicted.item()]) < per_class:
                # Add until the predicted class holds per_class samples.
                class_dict[predicted.item()].append(samples[idx])
            else:
                # Add the rest of them to remaining.
                remaining.append(samples[idx])

    sample2k = []
    for items in class_dict.values():
        sample2k.extend(items)
        if len(items) < per_class:
            # Supplement samples from remaining.
            add = per_class - len(items)
            sample2k.extend(remaining[:add])
            remaining = remaining[add:]
    return sample2k


def get_confidence(models, samples, cycle):
    '''Default: confidence sampling (pseudo labeling).
    Returns the num_samples_list[cycle] samples with the lowest top-1 score.'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,
                    transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    top1_scores = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            outputs = F.normalize(outputs, dim=1)
            # Save the top-1 confidence score.
            probs = F.softmax(outputs, dim=1)
            top1_scores.append(probs[0][predicted.item()].item())
    idx = np.argsort(top1_scores)
    samples = np.array(samples)
    # Ascending sort: the least confident samples come first.
    return samples[idx[:num_samples_list[cycle]]]


def get_entropy(models, samples, cycle):
    '''Entropy sampling: returns the num_samples_list[cycle] samples with the highest predictive entropy.'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,
                    transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    entropies = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            e = -1.0 * torch.sum(F.softmax(outputs, dim=1) * F.log_softmax(outputs, dim=1), dim=1)
            entropies.append(e.item())
    idx = np.argsort(entropies)
    samples = np.array(samples)
    # Ascending sort: the highest-entropy samples are at the tail.
    return samples[idx[-num_samples_list[cycle]:]]


def get_kmeans(models, samples, cycle):
    '''K-means sampling: cluster the features and pick the sample closest to each cluster center.'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,
                    transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    feats = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            _, feat = net(inputs, is_feat=True)
            feats.append(np.array(torch.squeeze(feat).cpu()))
    k_means = KMeans(init='k-means++', n_clusters=num_samples_list[cycle], n_init=10).fit(feats)
    # Distance from each sample to each cluster center, shape (num_samples, num_clusters).
    distances = k_means.transform(feats)
    # Index of the sample closest to each cluster center (column-wise argmin).
    center_idx = np.argmin(distances, axis=0)
    samples = np.array(samples)
    return samples[center_idx]


def get_random(models, samples, cycle):
    '''Random sampling (without replacement).'''
    idx = random.sample(range(len(samples)), num_samples_list[cycle])
    samples = np.array(samples)
    return samples[idx]


def save(name, file):
    np.save(name, file)
    print(name + ' saved!')


if __name__ == '__main__':
    args = get_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0

    args.num_images = get_num_images(args.dataset)
    num_samples_list = [int(i*0.01*args.num_images) for i in args.per_samples_list]
    args.cycles = len(num_samples_list) if args.cycles is None else args.cycles
    print(f'num samples list: {num_samples_list}')

    # Data
    print('==> Preparing data..')
    assert args.dataset in cls_dict
    args.num_classes = cls_dict[args.dataset]
    transform_train, transform_test = trans_dict[args.dataset]
    testset = Loader(path=os.path.join(args.datapath, args.dataset), is_train=False, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    # Model
    print('==> Loading net {}'.format(args.net))
    net = models_dict[args.net](num_classes=args.num_classes).to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    models = {'backbone': net}

    labeled_path = os.path.join(args.save, 'labeled')
    os.makedirs(labeled_path, exist_ok=True)
    logger = setup_logger(name='main', output=args.save)
    logger.info(args)
    logger.info(f'==> Saving in {args.save}')

    sampling_dict = {
        'class_balanced': get_class_balanced,
        'confidence': get_confidence,
        'entropy': get_entropy,
        'rand': get_random,
        'kmeans': get_kmeans,
    }
    get_sampling = sampling_dict[args.sampling]

    auto_resume = False
    if auto_resume:
        main_best_file = os.path.join(args.save, 'main_best.txt')
        if os.path.exists(main_best_file):
            with open(main_best_file) as f:
                lines = f.readlines()
            if len(lines) > 1:
                latest_cycle = int(lines[-1].split()[0])
                args.resume = os.path.join(args.save, f'main_{latest_cycle}.pth')
                args.start_cycle = latest_cycle + 1
                print(f'Auto resume from {args.resume}, start from {args.start_cycle}')

    print('==> Training..')
    start_time = time.time()
    labeled = []
    for cycle in range(args.start_cycle, args.cycles):
        criterion = criterion_test = nn.CrossEntropyLoss()
        optim_backbone = optim.SGD(models['backbone'].parameters(), lr=args.lr,
                                   momentum=args.momentum, weight_decay=args.weight_decay)
        sched_backbone = torch.optim.lr_scheduler.MultiStepLR(optim_backbone, milestones=args.milestones)
        optimizers = {'backbone': optim_backbone}
        schedulers = {'backbone': sched_backbone}
        best_acc = 0
        logger.info('Cycle: {}'.format(cycle))

        if cycle == args.start_cycle:
            if args.sorted_dataset_path.endswith('.npy'):
                sorted_dataset = np.load(args.sorted_dataset_path)
            elif args.sorted_dataset_path.endswith('.txt'):
                with open(os.path.join(args.sorted_dataset_path), 'r') as f:
                    sorted_dataset = [line.strip() for line in f.readlines()]
            else:
                raise ValueError(args.sorted_dataset_path)

        # Select the window of the sorted pool that this cycle may sample from.
        if cycle == 0:
            begin_idx = 0
            end_idx = int(args.beta*args.addendum)
            samples = sorted_dataset[:end_idx]
            logger.info(f'sample range: {0} -> {end_idx} unlabeled length: {len(samples)}')
        elif cycle == args.cycles - 1:
            begin_idx = int(args.num_images - args.beta*args.addendum)
            end_idx = args.num_images
            ori_samples = sorted_dataset[begin_idx:]
            samples = np.setdiff1d(ori_samples, labeled)
            logger.info(f'sample range: {begin_idx} -> {end_idx} length: {len(ori_samples)} unlabeled length: {len(samples)}')
        else:
            begin_idx = int((cycle+1)*args.addendum - (args.beta-1)/2*args.addendum)
            end_idx = int((cycle+2)*args.addendum + (args.beta-1)/2*args.addendum)
            ori_samples = sorted_dataset[begin_idx:end_idx]
            samples = np.setdiff1d(ori_samples, labeled)
            logger.info(f'sample range: {begin_idx} -> {end_idx} length: {len(ori_samples)} unlabeled length: {len(samples)}')

        if cycle > 0:
            print('>> Getting previous checkpoint')
            if args.start_cycle > 0 and cycle == args.start_cycle:
                # Resume: restore the model from args.resume and the labeled pool saved by the previous cycle.
                assert os.path.exists(args.resume)
                checkpoint = torch.load(args.resume)
                labeled = list(np.load(os.path.join(labeled_path, f'labeled_{cycle - 1}.npy')))
                best_acc = checkpoint['acc']
            else:
                checkpoint = torch.load(os.path.join(args.save, f'main_{cycle - 1}.pth'))
            models['backbone'].load_state_dict(checkpoint['net'])
            print(f"cycle {cycle - 1} best acc: {checkpoint['acc']}")
            sample2k = get_sampling(models, samples, cycle)
        else:
            # First cycle: no trained model yet, so select directly from the sorted pool.
            samples = np.array(samples)
            k = num_samples_list[cycle]
            if args.first == 'uni':
                sample2k = samples[[int(j*len(samples)/k) for j in range(k)]]
            elif args.first == 'high1st':
                sample2k = samples[:k]
            elif args.first == 'low1st':
                sample2k = samples[-k:]
            elif args.first == 'rand':
                sample2k = np.array(random.sample(list(samples), k))
            else:
                raise ValueError('First-cycle selection method {} is not supported!'.format(args.first))

        labeled.extend(sample2k)
        np.save(os.path.join(labeled_path, f'labeled_{cycle}.npy'), labeled)
        logger.info(f'>> Labeled length: {len(labeled)}')
        trainset = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=True,
                           transform=transform_train, path_list=labeled)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

        for epoch in range(args.epochs):
            train(models, criterion, optimizers, epoch, trainloader)
            test(models, criterion_test, epoch, cycle)
            present_time = time.time()
            epochs_before = epoch + 1 + (cycle - args.start_cycle) * args.epochs
            epochs_after = args.epochs - epoch - 1 + (args.cycles - cycle - 1) * args.epochs
            eta_seconds = (present_time - start_time) / epochs_before * epochs_after
            eta = '{}d {}'.format(int(eta_seconds // 86400), time.strftime('%H:%M:%S', time.gmtime(eta_seconds % 86400)))
            print('Eta: {}'.format(eta))
            schedulers['backbone'].step()

        with open(os.path.join(args.save, 'main_best.txt'), 'a') as f:
            f.write(str(cycle) + ' ' + str(best_acc) + '\n')
    print('done')