本项目基于TextCNN,测试三种对抗训练模型(FGSM,PGD,FREE)在text classification上的表现。主要参考论文Fast is better than free: Revisiting adversarial training涉及的三个对抗训练方法:FGSM(Fast Gradient Sign Method)、PGD(projected gradient decent)、FREE(Free adversarial based on FGSM)。这三种方法主要差异在于delta、alpha参数的初始化和更新方式上,其差异性可以见下面三个模型对应的伪代码。
python3.7
torch 1.8.0+cu111
scikit-learn
scipy
numpy
本实验同样是使用THUCNews的一个子集进行训练与测试,数据集请自行到THUCTC:一个高效的中文文本分类工具包下载,请遵循数据提供方的开源协议;
文本类别涉及10个类别:categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'];
cnews.train.txt: 训练集(5000*10)
cnews.val.txt: 验证集(500*10)
cnews.test.txt: 测试集(1000*10)
训练所用的数据,以及训练好的词向量可以下载:链接: https://pan.baidu.com/s/1DOgxlY42roBpOKAMKPPKWA,密码: up9d
本实验在文本按字级别Embedding上,利用对抗训练方法产生attack后,然后再加入embedding中,最后利用cnn来进行文本特征学习。其实现部分代码如下:
def forward(self, inputs_ids,attack=None,is_training=True):
embs = self.embedding(inputs_ids)
if attack is not None:
embs=embs+attack #加入干扰信息
embs=embs.unsqueeze(1)
....
out = self.fc2(fc)
return out
此外,论文涉及到三个超参数分别为epsilon、alpha、attack_iters,同时epsilon会控制delta参数的生成。本实验中,根据论文在图像领域的设定,以及自己主观经验判断和小范围的网格搜索方式设定以下具体值。相比论文中根据数据分布的均值和方差方式来设定,本项目显得更粗糙些,但为了达到只是实验对比的目的,此设定也是有效的。
epsilon = torch.tensor(0.1)
alpha= 0.04
attack_iters=5
首先在config.py中选择要运行的mode,训练与测试分别执行如下:
python run.py train
python run.py test
四种模型训练20轮后,在测试集上实验结果如下:
Model | Accuracy | Precision | Recall | F1-score |
---|---|---|---|---|
TextCNN | 95.14 | 95.16 | 95.14 | 95.11 |
FGSM | 95.53 | 95.60 | 95.53 | 95.50 |
PGD | 95.63 | 95.67 | 95.63 | 95.60 |
FREE | 95.49 | 95.54 | 94.49 | 95.46 |
四种模型训练消耗的时间(minutes)对比如下:
model | total_cost | mean_cost |
---|---|---|
TextCNN | 3.7 | 0.185 |
FGSM | 5.83 | 0.2915 |
PGD | 13.54 | 0.677 |
FREE | 12.22 | 0.611 |
通过本次实验,有以下几点结论与想法:
- 对抗训练技术方法的确有助于提高文本分类任务的效果;
- FGSM方法虽然提高了训练效率,但并不影响推理速度,而且NLP领域任务都不用大的轮数,所以PGD方法更合适些;
- 三种方法涉及delta、alpha超参数的初始化设定,面临不同的任务,会有变动,增加探寻合适参数的难度;
- 在文本分类中,觉得用word2vec或者bert方式初始化向量来进行干扰样本生成,应会比随机初始化embedding方式更合适,而且可以根据高频率词的分布来初始化delta、alpha更合理;
- 若在本论文提出的改进版FGSM基础上班,考虑如何更稳定或自动化的方式初始化delta等参数,是一个值得优化的方向。
1.FAST IS BETTER THAN FREE: REVISITING ADVERSARIAL TRAINING
2.https://github.com/locuslab/fast_adversarial
3.Adversarial Training Methods for Semi-Supervised Text Classification.
4.[论文笔记] Projected Gradient Descent (PGD)