Model-Agnostic Meta-Learning[1](MAML)算法是一种模型无关的元学习算法,其模型无关体现在,能够与任何使用了梯度下降法的模型相兼容,广泛应用于各种不同的机器学习任务,包括分类、识别、强化学习等领域。
元学习的目标,是在大量不同的任务上训练一个模型,使其能够使用极少量的训练数据(即小样本),进行极少量的梯度下降步数,就能够迅速适应新任务,解决新问题。
在本项目复现的文献中,通过对模型参数进行显式训练,从而获得在各种任务下均能良好泛化的模型初始化参数。当面临小样本的新任务时,使用该初始化参数,能够在单步(或多步)梯度更新后,实现对该任务的学习和适配。为了复现文献中的实验结果,本项目基于paddlepaddle深度学习框架,在omniglot数据集上进行训练和测试,目标是达到并超过原文献的模型性能。
Omniglot 数据集包含50个不同的字母表,每个字母表中的字母各包含20个手写字符样本,每一个手写样本都是不同的人通过亚马逊的 Mechanical Turk 在线绘制的。Omniglot数据集的多样性强于MNIST数据集,是增强版的MNIST,常用与小样本识别任务。
考虑一个关于任务T的分布p(T),我们希望模型能够对该任务分布很好的适配。在K-shot(即K个学习样本)的学习任务下,从p(T)分布中随机采样一个新任务Ti,在任务Ti的样本分布qi中随机采样K个样本,用这K个样本训练模型,获得LOSS,实现对模型f的内循环更新。然后再采样query个样本,评估新模型的LOSS,然后对模型f进行外循环更新。反复上述过程,从而使最终模型能够对任务分布p(T)上的所有情况,能够良好地泛化。算法可用下图进行示意。
2.2 算法流程 MAML算法针对小样本图像分类任务的计算流程,如下图所示:
本项目的难点在于,算法包含外循环和内循环两种梯度更新方式。内循环针对每一种任务T进行梯度更新,用更新后的模型重新评估LOSS;而外循环则要使用内循环中更新后的LOSS,在所有任务上更新原始模型。 使用paddle经典的动态图框架,在内循环更新完成后,模型各节点参数已经发生变化,loss已无法反传到先前的模型参数上。外循环的参数更新公式为
这里,要使用θ_i^'参数模型计算的LOSS,反传回θ,使用经典动态图模型架构无法实现。本方案通过自定义参数的方式,使函数层层级联,实现更灵活的参数控制。
本项目提供了pycharm完整工程,以及notebook工程文件。使用时需要根据代码注释提前将数据集进行解压,并执行make_data.py生成训练集、验证集和测试集数据文件。
当然,也可以直接使用本项目生成好的文件。
之后根据文件名进行相应任务的训练和测试。本项目将训练和测试代码写在了一个文件中,如果用户需要直接借用已训练好的模型直接进行评估,可以将测试部分代码自行拷贝出来,单独运行测试代码。
可运行的notebook工程,参见:https://aistudio.baidu.com/aistudio/projectdetail/1869590