-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
30 changed files
with
4,237 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Being Bayesian about Categorical Probability | ||
|
||
This repository is the official implementation of ICML'2020 paper ["Being Bayesian about Categorical Probability."](https://arxiv.org/abs/2002.07965) | ||
|
||
The proposed framework, called belief matching framework, regards the categorical probability as a random variable and then constructs the Dirichlet target distribution over the categorical distribution by means of the Bayesian inference. Then, the neural network is trained to match its approximate distribution to the target distribution, which can be implemented by replacing only the softmax-cross entropy loss with the belief matching loss. | ||
|
||
The code is designed to run on ```Python >= 3.5``` using the dependencies listed in ```requirements.txt```. You can install the dependencies by | ||
|
||
``` | ||
$ pip3 install -r requirements.txt. | ||
``` | ||
|
||
|
||
|
||
## Training | ||
|
||
Experimental results presented in the paper can be reproduced by following instructions. | ||
|
||
**CIFAR** | ||
|
||
Following scripts train ResNet-18 and ResNet-50 with the belief matching loss on CIFAR-10 and CIFAR-100 (use ```--coeff -1.0``` to train neural nets with the softmax-cross entropy loss). | ||
``` | ||
$ python cifar_trainer.py --arch resnet18 --coeff 0.01 --dataset cifar10 --save-dir benchmark --gpu 0 | ||
$ python cifar_trainer.py --arch resnet18 --coeff 0.003 --dataset cifar100 --save-dir benchmark --gpu 0 | ||
$ python cifar_trainer.py --arch resnet50 --coeff 0.003 --dataset cifar10 --save-dir benchmark --gpu 0 | ||
$ python cifar_trainer.py --arch resnet50 --coeff 0.001 --dataset cifar100 --save-dir benchmark --gpu 0 | ||
``` | ||
**ImageNet** | ||
|
||
Following scripts train ResNext-50 and ResNext-101 with the belief matching loss on ImageNet (use ```--coeff -1.0``` to train neural nets with the softmax-cross entropy loss). | ||
``` | ||
$ python imagenet_trainer.py --arch ResNext50 --coeff 0.001 --data DATA_DIR --save-dir benchmark | ||
$ python imagenet_trainer.py --arch ResNext101 --coeff 0.0001 --data DATA_DIR --save-dir benchmark | ||
``` | ||
|
||
|
||
|
||
Instructions and codes for transfer learning and semi-supervised learning are in ```transfer_learning``` and ```semi_supervised_learning```, respectively. | ||
|
||
|
||
|
||
## Reference | ||
|
||
Our code is based on the following public repositories: | ||
* CIFAR: https://github.com/facebookresearch/mixup-cifar10 | ||
* ImageNet: https://github.com/hongyi-zhang/Fixup/tree/master/imagenet | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
import sys | ||
import torchvision | ||
import argparse | ||
import os | ||
import shutil | ||
import time | ||
import numpy as np | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim | ||
import torch.utils.data | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
import resnet | ||
|
||
from utils import AverageMeter, save_checkpoint, accuracy | ||
from loss import BeliefMatchingLoss | ||
|
||
parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch') | ||
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32') | ||
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', | ||
help='number of data loading workers (default: 4)') | ||
parser.add_argument('--epochs', default=200, type=int, metavar='N', | ||
help='number of total epochs to run') | ||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', | ||
help='manual epoch number (useful on restarts)') | ||
parser.add_argument('-b', '--batchsize', default=128, type=int, | ||
metavar='N', help='mini-batch size (default: 128)') | ||
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, | ||
metavar='LR', help='initial learning rate') | ||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', | ||
help='momentum') | ||
parser.add_argument('--weightdecay', '--wd', default=1e-4, type=float, | ||
metavar='W', help='weight decay (default: 5e-4)') | ||
parser.add_argument('--print-freq', '-p', default=50, type=int, | ||
metavar='N', help='print frequency (default: 20)') | ||
parser.add_argument('--resume', default='', type=str, metavar='PATH', | ||
help='path to latest checkpoint (default: none)') | ||
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', | ||
help='evaluate model on validation set') | ||
parser.add_argument('--pretrained', dest='pretrained', action='store_true', | ||
help='use pre-trained model') | ||
parser.add_argument('--half', dest='half', action='store_true', | ||
help='use half-precision(16-bit) ') | ||
parser.add_argument('--save-dir', dest='save_dir', | ||
help='The directory used to save the trained models', | ||
default='save_temp', type=str) | ||
parser.add_argument('--save-every', dest='save_every', | ||
help='Saves checkpoints at every specified number of epochs', | ||
type=int, default=10) | ||
parser.add_argument('--gpu', default='1,2', type=str) | ||
parser.add_argument('--coeff', default=1e-2, type=float, help='Coefficient to KL term in BM loss. Set -1 to use CrossEntropy Loss') | ||
parser.add_argument('--prior', default=1.0, type=float, help='Dirichlet prior parameter') | ||
parser.add_argument('--dataset', default='cifar10', type=str) | ||
parser.add_argument('--num-eval', dest='num_eval', default=100, type=int, help='Evaluation count for MC dropout') | ||
|
||
best_prec1 = 0 | ||
test_error_best = -1 | ||
|
||
def main(): | ||
global args, best_prec1 | ||
args = parser.parse_args() | ||
|
||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu | ||
|
||
if args.dataset == 'cifar10': | ||
mean = (0.4914, 0.4822, 0.4465) | ||
std = (0.2023, 0.1994, 0.2010) | ||
else: | ||
mean = (0.5071, 0.4867, 0.4408) | ||
std = (0.2675, 0.2565, 0.2761) | ||
|
||
transform_train = transforms.Compose([ | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean, std), | ||
]) # meanstd transformation | ||
|
||
transform_test = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean, std), | ||
]) | ||
|
||
if(args.dataset == 'cifar10'): | ||
print("| Preparing CIFAR-10 dataset...") | ||
sys.stdout.write("| ") | ||
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) | ||
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test) | ||
num_classes = 10 | ||
elif(args.dataset == 'cifar100'): | ||
print("| Preparing CIFAR-100 dataset...") | ||
sys.stdout.write("| ") | ||
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) | ||
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=False, transform=transform_test) | ||
num_classes = 100 | ||
|
||
# Creating data indices for training and validation splits: | ||
validation_split = 0.2 | ||
dataset_size = len(trainset) | ||
num_val = int(np.floor(validation_split * dataset_size)) | ||
|
||
trainset, valset = torch.utils.data.random_split(trainset, [dataset_size-num_val, num_val]) | ||
|
||
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batchsize, shuffle=True, num_workers=args.workers, pin_memory=True) | ||
val_loader = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=True, num_workers=args.workers, pin_memory=True) | ||
test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=args.workers, pin_memory=True) | ||
|
||
# Check the save_dir exists or not | ||
if not os.path.exists(args.save_dir): | ||
os.makedirs(args.save_dir) | ||
|
||
if args.arch == 'resnet50': | ||
model = resnet.ResNet50(num_classes) | ||
elif args.arch =='resnet18': | ||
model = resnet.ResNet18(num_classes) | ||
|
||
model = torch.nn.DataParallel(model) | ||
model.cuda() | ||
|
||
# optionally resume from a checkpoint | ||
if args.resume: | ||
if os.path.isfile(args.resume): | ||
print("=> loading checkpoint '{}'".format(args.resume)) | ||
checkpoint = torch.load(args.resume) | ||
args.start_epoch = checkpoint['epoch'] | ||
best_prec1 = checkpoint['best_prec1'] | ||
model.load_state_dict(checkpoint['state_dict']) | ||
print("=> loaded checkpoint '{}' (epoch {})" | ||
.format(args.evaluate, checkpoint['epoch'])) | ||
else: | ||
print("=> no checkpoint found at '{}'".format(args.resume)) | ||
|
||
cudnn.benchmark = True | ||
|
||
# define loss function (criterion) and pptimizer | ||
if args.coeff == -1: | ||
criterion = nn.CrossEntropyLoss().cuda() | ||
else: | ||
criterion = BeliefMatchingLoss(args.coeff, args.prior) | ||
|
||
eval_criterion = nn.CrossEntropyLoss().cuda() | ||
|
||
if args.coeff == -1: | ||
name = '{}-softmax'.format(args.arch) | ||
else: | ||
name = '{}-bm-coeff{}-prior{}'.format(args.arch, args.coeff, args.prior) | ||
name += '-{}'.format(args.dataset) | ||
|
||
optimizer = torch.optim.SGD(model.parameters(), args.lr, | ||
momentum=args.momentum, | ||
weight_decay=args.weightdecay) | ||
|
||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, | ||
milestones=[100, 150], last_epoch=args.start_epoch - 1) | ||
|
||
# Warmup | ||
for param_group in optimizer.param_groups: | ||
param_group['lr'] = args.lr*0.1 | ||
|
||
if args.evaluate: | ||
validate(val_loader, model, criterion) | ||
return | ||
|
||
for epoch in range(args.start_epoch, args.epochs): | ||
if epoch < 6 and epoch > 0: | ||
param_group['lr'] = args.lr*0.1*(epoch*2) | ||
|
||
# train for one epoch | ||
print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) | ||
train(train_loader, model, criterion, optimizer, epoch) | ||
lr_scheduler.step() | ||
|
||
# evaluate on validation set | ||
prec1 = validate(val_loader, model, criterion) | ||
|
||
# remember best prec@1 and save checkpoint | ||
is_best = prec1 > best_prec1 | ||
best_prec1 = max(prec1, best_prec1) | ||
|
||
if is_best: | ||
test_prec1 = validate(test_loader, model, criterion) | ||
save_checkpoint({ | ||
'state_dict': model.state_dict(), | ||
'best_prec1': test_prec1, | ||
'epoch': epoch, | ||
}, is_best, filename=os.path.join(args.save_dir, 'checkpoint_{}.th'.format(name))) | ||
|
||
|
||
def train(train_loader, model, criterion, optimizer, epoch): | ||
""" | ||
Run one train epoch | ||
""" | ||
batch_time = AverageMeter() | ||
data_time = AverageMeter() | ||
losses = AverageMeter() | ||
top1 = AverageMeter() | ||
mi_meter = AverageMeter() | ||
|
||
# switch to train mode | ||
model.train() | ||
|
||
end = time.time() | ||
for i, (input, target) in enumerate(train_loader): | ||
# measure data loading time | ||
data_time.update(time.time() - end) | ||
|
||
target = target.cuda() | ||
input_var = input.cuda() | ||
target_var = target | ||
if args.half: | ||
input_var = input_var.half() | ||
|
||
# compute output | ||
output = model(input_var) | ||
|
||
loss = criterion(output, target_var) | ||
|
||
# compute gradient and do SGD step | ||
optimizer.zero_grad() | ||
loss.backward() | ||
for p in model.parameters(): | ||
nn.utils.clip_grad_norm_(p, 1.) | ||
optimizer.step() | ||
|
||
output = output.float() | ||
loss = loss.float() | ||
# measure accuracy and record loss | ||
prec1 = accuracy(output.data, target)[0] | ||
losses.update(loss.item(), input.size(0)) | ||
top1.update(prec1.item(), input.size(0)) | ||
|
||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
|
||
if i % args.print_freq == 0: | ||
print('Epoch: [{0}][{1}/{2}]\t' | ||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' | ||
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' | ||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' | ||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( | ||
epoch, i, len(train_loader), batch_time=batch_time, | ||
data_time=data_time, loss=losses, top1=top1)) | ||
|
||
|
||
def validate(val_loader, model, criterion, mc_eval=False): | ||
""" | ||
Run evaluation | ||
""" | ||
batch_time = AverageMeter() | ||
losses = AverageMeter() | ||
top1 = AverageMeter() | ||
mi_meter = AverageMeter() | ||
|
||
# switch to evaluate mode | ||
model.eval() | ||
|
||
end = time.time() | ||
with torch.no_grad(): | ||
for i, (input, target) in enumerate(val_loader): | ||
target = target.cuda() | ||
input_var = input.cuda() | ||
target_var = target.cuda() | ||
|
||
if args.half: | ||
input_var = input_var.half() | ||
|
||
# compute output | ||
if mc_eval: | ||
outputs = F.softmax(model(input_var)) | ||
for _ in range(args.num_eval-1): | ||
outputs += (F.softmax(model(input_var))) | ||
output = outputs / float(args.num_eval) | ||
else: | ||
output = model(input_var) | ||
loss = criterion(output, target_var) | ||
|
||
output = output.float() | ||
loss = loss.float() | ||
|
||
# measure accuracy and record loss | ||
prec1 = accuracy(output.data, target)[0] | ||
losses.update(loss.item(), input.size(0)) | ||
top1.update(prec1.item(), input.size(0)) | ||
|
||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
|
||
if i % args.print_freq == 0: | ||
print('Test: [{0}/{1}]\t' | ||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' | ||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' | ||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( | ||
i, len(val_loader), batch_time=batch_time, loss=losses, | ||
top1=top1)) | ||
|
||
print(' * Prec@1 {top1.avg:.3f}' | ||
.format(top1=top1)) | ||
|
||
return top1.avg #, mi_meter.avg | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.