Initial commit
tjoo512 committed Jun 29, 2020
1 parent 2630fe6 commit de18ad1
Showing 30 changed files with 4,237 additions and 0 deletions.
47 changes: 47 additions & 0 deletions README.md
# Being Bayesian about Categorical Probability

This repository is the official implementation of ICML'2020 paper ["Being Bayesian about Categorical Probability."](https://arxiv.org/abs/2002.07965)

The proposed framework, called the belief matching framework, treats the categorical probability as a random variable and constructs a Dirichlet target distribution over it by Bayesian inference. The neural network is then trained to match its approximate distribution to this target distribution, which can be implemented by simply replacing the softmax cross-entropy loss with the belief matching loss.
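
For reference, below is a minimal, hypothetical sketch of what a belief-matching-style loss can look like in PyTorch, assuming (as an illustration, not a description of this repository's ```loss.py```) that the Dirichlet concentrations are obtained by exponentiating the logits and that the target prior is a symmetric Dirichlet with parameter ```prior```. The expected log-likelihood of the label under the approximate Dirichlet and the KL divergence to the prior both have closed forms, and the KL term is weighted by the coefficient that the training scripts expose as ```--coeff```.

```
import torch
import torch.nn as nn

class BeliefMatchingSketch(nn.Module):
    """Hypothetical sketch of a belief-matching-style loss (see loss.py for the real implementation)."""
    def __init__(self, coeff=0.01, prior=1.0):
        super().__init__()
        self.coeff = coeff   # weight of the KL term (the --coeff argument)
        self.prior = prior   # symmetric Dirichlet prior parameter (the --prior argument)

    def forward(self, logits, target):
        # treat exp(logits) as concentrations of an approximate Dirichlet posterior over class probabilities
        alpha = logits.exp()
        alpha0 = alpha.sum(dim=1)
        beta = torch.full_like(alpha, self.prior)
        beta0 = beta.sum(dim=1)

        # expected log-likelihood of the observed label under Dir(alpha): psi(alpha_y) - psi(alpha_0)
        alpha_y = alpha.gather(1, target.unsqueeze(1)).squeeze(1)
        ell = torch.digamma(alpha_y) - torch.digamma(alpha0)

        # closed-form KL( Dir(alpha) || Dir(beta) ) against the symmetric prior
        kl = (torch.lgamma(alpha0) - torch.lgamma(alpha).sum(dim=1)
              - torch.lgamma(beta0) + torch.lgamma(beta).sum(dim=1)
              + ((alpha - beta) * (torch.digamma(alpha)
                                   - torch.digamma(alpha0).unsqueeze(1))).sum(dim=1))

        # negative ELBO, averaged over the batch
        return (self.coeff * kl - ell).mean()
```

The loss actually used in the experiments is the ```BeliefMatchingLoss``` class imported from ```loss.py``` by the trainers; the class name and exact parameterization above are assumptions made only for illustration.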

The code is designed to run on ```Python >= 3.5``` using the dependencies listed in ```requirements.txt```. You can install the dependencies with

```
$ pip3 install -r requirements.txt
```



## Training

The experimental results presented in the paper can be reproduced with the following instructions.

**CIFAR**

The following scripts train ResNet-18 and ResNet-50 with the belief matching loss on CIFAR-10 and CIFAR-100 (use ```--coeff -1.0``` to train with the softmax cross-entropy loss instead).
```
$ python cifar_trainer.py --arch resnet18 --coeff 0.01 --dataset cifar10 --save-dir benchmark --gpu 0
$ python cifar_trainer.py --arch resnet18 --coeff 0.003 --dataset cifar100 --save-dir benchmark --gpu 0
$ python cifar_trainer.py --arch resnet50 --coeff 0.003 --dataset cifar10 --save-dir benchmark --gpu 0
$ python cifar_trainer.py --arch resnet50 --coeff 0.001 --dataset cifar100 --save-dir benchmark --gpu 0
```
**ImageNet**

The following scripts train ResNeXt-50 and ResNeXt-101 with the belief matching loss on ImageNet (use ```--coeff -1.0``` to train with the softmax cross-entropy loss instead).
```
$ python imagenet_trainer.py --arch ResNext50 --coeff 0.001 --data DATA_DIR --save-dir benchmark
$ python imagenet_trainer.py --arch ResNext101 --coeff 0.0001 --data DATA_DIR --save-dir benchmark
```



Instructions and code for transfer learning and semi-supervised learning are in the ```transfer_learning``` and ```semi_supervised_learning``` directories, respectively.



## Reference

Our code is based on the following public repositories:
* CIFAR: https://github.com/facebookresearch/mixup-cifar10
* ImageNet: https://github.com/hongyi-zhang/Fixup/tree/master/imagenet

310 changes: 310 additions & 0 deletions cifar_trainer.py
import sys
import torchvision
import argparse
import os
import shutil
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet

from utils import AverageMeter, save_checkpoint, accuracy
from loss import BeliefMatchingLoss

parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR-10/100 in PyTorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    help='model architecture: resnet18 or resnet50 (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batchsize', default=128, type=int,
metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weightdecay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision (16-bit)')
parser.add_argument('--save-dir', dest='save_dir',
help='The directory used to save the trained models',
default='save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
help='Saves checkpoints at every specified number of epochs',
type=int, default=10)
parser.add_argument('--gpu', default='1,2', type=str)
parser.add_argument('--coeff', default=1e-2, type=float, help='Coefficient to KL term in BM loss. Set -1 to use CrossEntropy Loss')
parser.add_argument('--prior', default=1.0, type=float, help='Dirichlet prior parameter')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--num-eval', dest='num_eval', default=100, type=int, help='Evaluation count for MC dropout')

best_prec1 = 0
test_error_best = -1

def main():
global args, best_prec1
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

if args.dataset == 'cifar10':
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
else:
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)

transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std),
]) # meanstd transformation

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std),
])

if(args.dataset == 'cifar10'):
print("| Preparing CIFAR-10 dataset...")
sys.stdout.write("| ")
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
num_classes = 10
elif(args.dataset == 'cifar100'):
print("| Preparing CIFAR-100 dataset...")
sys.stdout.write("| ")
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=False, transform=transform_test)
num_classes = 100

# Creating data indices for training and validation splits:
validation_split = 0.2
dataset_size = len(trainset)
num_val = int(np.floor(validation_split * dataset_size))

trainset, valset = torch.utils.data.random_split(trainset, [dataset_size-num_val, num_val])

train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batchsize, shuffle=True, num_workers=args.workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=True, num_workers=args.workers, pin_memory=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=args.workers, pin_memory=True)

# Check the save_dir exists or not
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)

    if args.arch == 'resnet50':
        model = resnet.ResNet50(num_classes)
    elif args.arch == 'resnet18':
        model = resnet.ResNet18(num_classes)
    else:
        raise ValueError('Unsupported architecture: {}'.format(args.arch))

model = torch.nn.DataParallel(model)
model.cuda()

# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.evaluate, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))

cudnn.benchmark = True

    # define loss function (criterion) and optimizer
if args.coeff == -1:
criterion = nn.CrossEntropyLoss().cuda()
else:
criterion = BeliefMatchingLoss(args.coeff, args.prior)

eval_criterion = nn.CrossEntropyLoss().cuda()

if args.coeff == -1:
name = '{}-softmax'.format(args.arch)
else:
name = '{}-bm-coeff{}-prior{}'.format(args.arch, args.coeff, args.prior)
name += '-{}'.format(args.dataset)

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weightdecay)

lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[100, 150], last_epoch=args.start_epoch - 1)

# Warmup
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr*0.1

if args.evaluate:
validate(val_loader, model, criterion)
return

for epoch in range(args.start_epoch, args.epochs):
        if 0 < epoch < 6:
            # linear warmup: ramp the learning rate back up to args.lr over the first five epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr * 0.1 * (epoch * 2)

# train for one epoch
print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
train(train_loader, model, criterion, optimizer, epoch)
lr_scheduler.step()

# evaluate on validation set
prec1 = validate(val_loader, model, criterion)

# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)

if is_best:
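            # evaluate on the test set only when validation accuracy improves, and checkpoint that model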
test_prec1 = validate(test_loader, model, criterion)
save_checkpoint({
'state_dict': model.state_dict(),
'best_prec1': test_prec1,
'epoch': epoch,
}, is_best, filename=os.path.join(args.save_dir, 'checkpoint_{}.th'.format(name)))


def train(train_loader, model, criterion, optimizer, epoch):
"""
Run one train epoch
"""
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
mi_meter = AverageMeter()

# switch to train mode
model.train()

end = time.time()
for i, (input, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)

target = target.cuda()
input_var = input.cuda()
target_var = target
if args.half:
input_var = input_var.half()

# compute output
output = model(input_var)

loss = criterion(output, target_var)

# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
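        # clip the gradient norm of each parameter tensor to 1.0 before the update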
for p in model.parameters():
nn.utils.clip_grad_norm_(p, 1.)
optimizer.step()

output = output.float()
loss = loss.float()
# measure accuracy and record loss
prec1 = accuracy(output.data, target)[0]
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1))


def validate(val_loader, model, criterion, mc_eval=False):
"""
Run evaluation
"""
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
mi_meter = AverageMeter()

# switch to evaluate mode
model.eval()

end = time.time()
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
target = target.cuda()
input_var = input.cuda()
target_var = target.cuda()

if args.half:
input_var = input_var.half()

# compute output
if mc_eval:
                outputs = F.softmax(model(input_var), dim=1)
                for _ in range(args.num_eval - 1):
                    outputs += F.softmax(model(input_var), dim=1)
                # average the predictive distributions over args.num_eval stochastic forward passes
                output = outputs / float(args.num_eval)
else:
output = model(input_var)
loss = criterion(output, target_var)

output = output.float()
loss = loss.float()

# measure accuracy and record loss
prec1 = accuracy(output.data, target)[0]
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1))

print(' * Prec@1 {top1.avg:.3f}'
.format(top1=top1))

return top1.avg #, mi_meter.avg


if __name__ == '__main__':
main()