本文将不确定性引入神经网络,将确定性参数的神经网络改造为具有随机特性的概率神经网络(也成贝叶斯神经网络)。本文是贝叶斯神经网络的奠基作之一,具有很高的引用量。
具体地,在传统神经网络中,各网络节点的参数为确定值;通过本文方法引入不确定性后,各网络节点的参数转变为满足概率分布的随机变量。 每次正向推理时,网络会根据概率分布对参数值进行采样,并以采样到的值作为本次正向推理的参数值。传统神经网络与贝叶斯神经网络的异同点如下图所示:
训练贝叶斯神经网络时,通过本文方法,可将loss函数反向传递到网络节点的概率分布参数上,从而动态调优该网络。
论文链接:Weight Uncertainty in Neural Networks
基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目达到的测试精度,如下表所示。 参考文献的最高精度为98.68%
模型和方法 | 本项目精度 |
---|---|
lenet-bbb | 98.75% |
alexnet-bbb | 98.73% |
3conv3fc-bbb | 99.07% |
lenet-lrt | 98.76% |
alexnet-lrt | 98.69% |
3conv3fc-lrt | 99.29% |
超参数配置如下:
超参数名 | 设置值 |
---|---|
lr | 0.01 |
batch_size | 256 |
epochs | 200 |
本项目使用的是MNIST数据集。该数据集为美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,50%来自人口普查局的工作人员。该数据集的收集目的是希望通过算法,实现对手写数字的识别。
- 数据集大小:
- MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。
- 数据格式:它包含了四个部分
- (1)Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- (2)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- (3)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- (4)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
数据集链接:MNIST
-
硬件:
- x86 cpu
- NVIDIA GPU
-
框架:
- PaddlePaddle = 2.1.2
-
其他依赖项:
- numpy==1.19.3
- matplotlib==3.3.4
- pandas==1.2.4
- pytest==6.2.4
- paddle==1.0.2
- Pillow==8.3.1
python train.py --net_type 3conv3fc --dataset MNIST
训练贝叶斯神经网络,运行完毕后,模型参数文件保存在./checkpoints/MNIST/bayesian目录下。
python test.py --net_type 3conv3fc --dataset MNIST
用于测试贝叶斯神经网络,测试前,将已训练好的最优参数模型从./results/3CONV3FC拷贝至./checkpoints/MNIST/bayesian
├── onfig_bayesian.py # 配置
├── metrics.py # 度量相关
├── README.md # readme
├── requirements.txt # 依赖
├── test # 测试
├── train # 启动训练入口
├── utils.py # 公共调用
├── checkpoints # 保存
│ ├── MNIST # 数据集名称
│ ├── bayesian
│ ├── best
├── data
│ ├── data.py
├── layers
│ ├── misc.py
│ ├── BBB
│ ├── BBBConv.py
│ ├── BBBLinear.py
├── models
│ ├── BayesianModels
│ ├── BayesianOriginNet.py
│ ├── BayesianLeNet.py
可以在 train.py
中设置训练与评估相关参数,具体如下:
参数 | 默认值 | 说明 | 其他 |
---|---|---|---|
--net_type | 3conv3fc, 可选 | 选择模型 | 可选择lenet/alexnet/3conv3fc/originet |
--dataset | MNIST, 可选 | 选择数据集 | 本项目目前仅支持MNIST |
可参考快速开始章节中的描述
执行训练开始后,将得到类似如下的输出。每一轮epoch
训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。
Epoch: 0 Training Loss: 957661.3024 Training Accuracy: 0.5314 Validation Loss: 6048323.2596 Validation Accuracy: 0.8872 train_kl_div: 108218176.5714
Validation loss decreased (inf --> 6048323.259558). Saving model ...
Epoch: 1 Training Loss: 620338.8870 Training Accuracy: 0.7838 Validation Loss: 4819156.8720 Validation Accuracy: 0.8885 train_kl_div: 90394454.2449
Validation loss decreased (6048323.259558 --> 4819156.872046). Saving model ...
Epoch: 2 Training Loss: 483882.8229 Training Accuracy: 0.8268 Validation Loss: 3822200.3844 Validation Accuracy: 0.8913 train_kl_div: 71920784.3061
Validation loss decreased (4819156.872046 --> 3822200.384351). Saving model ...
Epoch: 3 Training Loss: 390434.6679 Training Accuracy: 0.8332 Validation Loss: 2554367.9053 Validation Accuracy: 0.9018 train_kl_div: 48361270.8571
Validation loss decreased (3822200.384351 --> 2554367.905322). Saving model ...
Epoch: 4 Training Loss: 275255.5825 Training Accuracy: 0.8434 Validation Loss: 1809289.2232 Validation Accuracy: 0.9172 train_kl_div: 34525619.3469
可参考快速开始章节中的描述
此时的输出为:
Testing Accuracy: 0.9907
在不同的超参数配置下,模型的收敛效果、达到的精度指标有较大的差异,以下列举不同超参数配置下,实验结果的差异性,便于比较分析:
(1)学习率:
原文献采用的优化器与本项目一致,为Adam优化器,原文献学习率设置为0.001,本项目经调参发现, 学习率设置为0.01或0.0001时,网络有时会不收敛,该模型的稳定性存在可改进空间。
(2)epoch轮次
本项目训练时,采用的epoch轮次为200。LOSS和准确率在110个epoch附近已趋于稳定,模型处于收敛状态,下图为3CONV3FC-BBB的训练曲线。
本项目复现时遇到一个比较大的问题,用pytorch顺利跑通源代码后,修改至paddle框架下再次训练,发现模型不收敛,训练准确率一直维持在0.1附近(随机挑选概率), 模型完全没有学到东西。
针对此问题,我依次对dataset、dataloader、模型参数初始化、优化器、loss函数,甚至沿着整个计算图跟踪了梯度是否正确传递。 最终定位为paddle.max函数,使用该函数后,问题出现;屏蔽该函数后,问题消失。 经分析,应该是该函数不连续,不支持梯度传递。而pytorch版本的max函数则没有此问题。
训练完成后,模型保存在checkpoints目录下。
训练和测试日志保存在results目录下。
信息 | 说明 |
---|---|
发布者 | hrdwsong |
时间 | 2021.08 |
框架版本 | Paddle 2.1.2 |
应用场景 | 贝叶斯神经网络 |
支持硬件 | GPU、CPU |
Aistudio地址 | https://aistudio.baidu.com/aistudio/projectdetail/2301297?shared=1 |