A plug-and-play implementation of Bayesian fine-tuning for practically learning Bayesian Neural Networks
We provide a PyTorch implementation for learning Bayesian Neural Networks (BNNs) at low cost. We unfold the learning of a BNN into two steps: deterministic pre-training of the BNN's deep neural network (DNN) counterpart, followed by Bayesian fine-tuning.
For deterministic pre-training, we simply train a regular DNN via maximum a posteriori (MAP) estimation, which is realized by optimizing under a weight-decay regularizer (a minimal sketch follows). We can also reuse off-the-shelf pre-trained models from popular model zoos (e.g., PyTorch Hub).
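For illustration, here is a minimal MAP pre-training sketch. It is not part of this library: the Hub model, the toy data loader, and the hyperparameters are placeholder assumptions; the point is that setting `weight_decay` on a standard optimizer already realizes MAP estimation under a Gaussian prior.

```python
import torch
from torch.optim import SGD

# Placeholder: a single toy batch keeps the sketch self-contained.
train_loader = [(torch.randn(8, 3, 224, 224), torch.randint(0, 1000, (8,)))]

# Placeholder: any off-the-shelf pre-trained model works, e.g. from PyTorch Hub.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).cuda()

# Weight decay corresponds to a zero-mean Gaussian prior, so this is MAP estimation.
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-4)

model.train()
for input, target in train_loader:
    input, target = input.cuda(), target.cuda()
    loss = torch.nn.functional.cross_entropy(model(input), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```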
After deterministic pre-training, it is straightforward to convert the converged DNN into a BNN and to perform Bayesian fine-tuning with this library.
The current implementation only supports a mean-field Gaussian approximate posterior; more flexible distributions are under development.
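For reference, the mean-field family factorizes the posterior over individual weights and is sampled with the reparameterization trick. A sketch of the math is below; tying the deviations to the `psi` parameters seen later via, e.g., an exponential map is our assumption about one common parameterization, not a statement about this library's internals:

```latex
q(\mathbf{w}) = \prod_i \mathcal{N}\!\left(w_i \mid \mu_i, \sigma_i^2\right),
\qquad
w_i = \mu_i + \sigma_i \, \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, 1)
```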
For more details, refer to our BayesAdapter paper and GitHub page.
We apply the Bayesian fine-tuning paradigm to detect adversarial examples and observe promising results; see our LiBRe paper for more details. To reproduce LiBRe, check here.
Dependencies:
- Python 3
- torch 1.3.0+
- torchvision 0.4.1+
Install with pip:

```bash
pip install git+https://github.com/thudzj/ScalableBDL.git
```
With CIFAR-10 classification as an example, we can easily leverage this library to perform Bayesian fine-tuning upon a pre-trained wide-ResNet-28-10 model and to evaluate the resulting Bayesian posterior.
We first import the necessary modules for Bayesian fine-tuning:
```python
import torch
from torch.optim import SGD

from scalablebdl.bnn_utils import freeze, unfreeze, disable_dropout, Bayes_ensemble
from scalablebdl.prior_reg import PriorRegularizor
from scalablebdl.mean_field import PsiSGD, to_bayesian, to_deterministic
```
Then we load the pre-trained wide-ResNet model and disable any stochasticity (e.g., dropout) inside it:
```python
# `wrn` is the wide-ResNet-28-10 constructor from this repo's demo scripts.
net = wrn(pretrained=True, depth=28, width=10).cuda()
disable_dropout(net)
```
Since the model is still deterministic at this point, we can check its performance with a one-sample Bayes ensemble:
```python
eval_loss, eval_acc = Bayes_ensemble(test_loader, net, num_mc_samples=1)
print('Results of deterministic pre-training, '
      'eval loss {}, eval acc {}'.format(eval_loss, eval_acc))
```
To expand the point-estimate parameters into Bayesian random variables, we only need to invoke:
```python
bayesian_net = to_bayesian(net, num_mc_samples=args.num_mc_samples)
unfreeze(bayesian_net)  # make the variational parameters trainable
```
For fine-tuning, we build a single optimizer with separate parameter groups for the posterior means (the `mu` parameters) and the posterior deviation parameters (the `psi` parameters), together with a regularizer that accounts for the prior:
```python
# Separate the posterior means (mus) from the deviation parameters (psis).
mus, psis = [], []
for name, param in bayesian_net.named_parameters():
    if 'psi' in name:
        psis.append(param)
    else:
        mus.append(param)

optimizer = SGD([{"params": mus, "lr": 0.0008, "weight_decay": 2e-4},
                 {"params": psis, "lr": 0.1, "weight_decay": 0}],
                momentum=0.9, nesterov=True)
# num_data is the training-set size (50000 for CIFAR-10).
regularizer = PriorRegularizor(bayesian_net, decay=2e-4, num_data=50000,
                               num_mc_samples=args.num_mc_samples)
```
The regularizer absorbs the KL divergence between the approximate posterior and the prior: its `step` method injects the corresponding gradients into the parameters right before each optimizer update.
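In other words, each step approximately minimizes the standard variational objective; a sketch, with $N$ the training-set size passed as `num_data`:

```latex
\min_{\mu,\,\psi}\;
\mathbb{E}_{q_{\mu,\psi}(\mathbf{w})}\!\left[-\log p(\mathbf{y} \mid \mathbf{x}, \mathbf{w})\right]
\;+\; \frac{1}{N}\,\mathrm{KL}\!\left(q_{\mu,\psi}(\mathbf{w}) \,\middle\|\, p(\mathbf{w})\right)
```

The cross-entropy loss in the loop below handles the expected log-likelihood term, while `regularizer.step()` accounts for the KL term.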
After this preparation, we perform Bayesian fine-tuning just like fine-tuning a regular DNN, except that the optimization also invokes the prior regularizer:
```python
for epoch in range(args.epochs):
    bayesian_net.train()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = bayesian_net(input)
        loss = torch.nn.functional.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()
        regularizer.step()  # add the gradients of the KL/prior term
        optimizer.step()

        if i % 100 == 0:
            print("Epoch {}, ite {}/{}, loss {}".format(
                epoch, i, len(train_loader), loss.item()))

    eval_loss, eval_acc = Bayes_ensemble(test_loader, bayesian_net)
    print("Epoch {}, eval loss {}, eval acc {}".format(
        epoch, eval_loss, eval_acc))
```
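Though not used in the loop above, the imported `to_deterministic` can presumably convert the fine-tuned BNN back into a point-estimate DNN, mirroring `to_bayesian`; this round-trip is our assumption about the API, not documented behavior:

```python
# Assumption: to_deterministic inverts to_bayesian, collapsing each weight
# posterior back to a point estimate (e.g., its mean).
deterministic_net = to_deterministic(bayesian_net)
eval_loss, eval_acc = Bayes_ensemble(test_loader, deterministic_net,
                                     num_mc_samples=1)
```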
Check this for a complete and runnable script.
The following table compares the predictive performance of the fine-tuning starting point (DNN) and the resulting BNN. Note that we estimate the accuracy of the BNN via Bayes ensemble with 100 MC samples (see the snippet after the table).
| | CIFAR-10 (wide-ResNet-28-10) | ImageNet (ResNet-50) |
|---|---|---|
| DNN | 96.92% | 76.13% |
| BNN | 97.09% | 76.49% |
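The BNN numbers above can be reproduced with the same `Bayes_ensemble` helper used earlier, simply with a larger number of MC samples:

```python
eval_loss, eval_acc = Bayes_ensemble(test_loader, bayesian_net,
                                     num_mc_samples=100)
print('BNN with 100-sample Bayes ensemble: '
      'eval loss {}, eval acc {}'.format(eval_loss, eval_acc))
```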
Thanks to:
- @Harry24k (github: bayesian-neural-network-pytorch)
If you have any problems with this library or want to contribute to it, please send us an email at:
Please cite our papers if you use this code in your own work:
```bibtex
@article{deng2020bayesadapter,
  title={BayesAdapter: Being Bayesian, Inexpensively and Reliably, via Bayesian Fine-tuning},
  author={Deng, Zhijie and Zhang, Hao and Yang, Xiao and Dong, Yinpeng and Zhu, Jun},
  journal={arXiv preprint arXiv:2010.01979},
  year={2020}
}

@inproceedings{deng2021libre,
  title={LiBRe: A Practical Bayesian Approach to Adversarial Detection},
  author={Deng, Zhijie and Yang, Xiao and Xu, Shizhen and Su, Hang and Zhu, Jun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={972--982},
  year={2021}
}
```