prototypical networks属于小样本学习中基于度量的算法。它将support set的同一类样本的所有embedding的质心作为这一类的prototype。query直接用自己的embedding去找最近的prototype,从而完成分类。
详细分析算法流程: 输入训练集,其中是标签,是标签的子训练集。
在每个episode中,随机选取个样本作为support set ,同时选取同一类的不同的个样本作为query set ,通过 计算每个类的原型向量。并依此对所有query set的样本进行分类,通过得到的损失更新网络。
-
[Python-3.7]
-
[PaddlePaddle-2.1.2]
-
[Cuda-11.2 and Cudnn-8.1]
(一)数据处理
- 数据集下载
- 数据集预处理: 需要将
miniImageNet
数据集处理为如下格式:所有图像放到../miniImageNet/images/
目录下,然后将本库中data/split/
下的csv文件放到../miniImageNet/
目录下,最后修改train.py
和test.py
中的data_root
参数为miniImageNet
目录所在路径。
(二)修改参数配置
epoch
训练的epoch个数lr
学习率data_root
数据所在路径episode_size
每次优化的任务数train_episode
在训练阶段每个epoch的任务数test_episode
在测试阶段每个epoch的任务数way_num
训练时类别数量shot_num
支持集中每个类所含样本数量query_num
查询集中每个类所含样本数量test_way
测试时类别数量use_memory
是否使用内存
(三)模型训练
$ python /script/train_5w_1s.sh 用于训练5way-1shot
$ python /script/train_5w_5s.sh 用于训练5way-5shot
(四)模型测试
$ python /script/test_5w_1s.sh 用于测试5way-1shot
$ python /script/test_5w_5s.sh 用于测试5way-5shot
我们复现的方法在5way-1shot上超过论文的性能0.32%,在5way-5shot上超过论文的性能0.18%。并且我们的复现是使用6000个任务进行测试得到的结果,原文中使用的是600个任务,因此我们的结果具有更高的可信度。
miniImageNet Dataset
Setups | 1-Shot 5-Way | 5-Shot 5-Way |
---|---|---|
Paper | 49.42±0.78 | 68.20±0.66 |
Ours | 49.74±0.25 | 68.38±0.21 |
@inproceedings{snell2017prototypical,
title={Prototypical networks for few-shot learning},
author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
booktitle={Proceedings of the 31st International Conference on Neural Information Processing Systems},
pages={4080--4090},
year={2017}
}