Adversarial-Train-TextCNN-Pytorch

A repository that implements three adversarial training methods, FGSM, "Free", and PGD, from the paper "Fast is better than free: Revisiting adversarial training" by Eric Wong, Leslie Rice, and J. Zico Kolter, then leverages them to train TextCNN models and shows how they perform compared with a TextCNN model trained in the ordinary way.

Implementation

By introducing perturbations into the embeddings of the inputs, adversarial training regularizes the model's parameters to improve its robustness and generalization. The underlying assumption is that small perturbations on the inputs should not change the distribution of the outputs.

The paper presents three different adversaries, as follows:

1. FGSM

FGSM, short for Fast Gradient Sign Method, is summarized below:

(figure: FGSM pseudocode from the paper)

The pseudocode above shows, for each input x, how FGSM performs its adversarial attack:

  1. Initialize the perturbation δ from a uniform distribution between −ε and ε:

     δ ~ Uniform(−ε, ε)

  2. Add the perturbation to the input x, calculate the gradient of the loss with respect to δ, and update δ as below:

     δ ← δ + α · sign(∇_δ ℓ(f(x + δ), y))

  3. If the absolute value of δ is too great, project it back into (−ε, ε):

     δ ← max(min(δ, ε), −ε)

  4. Update the model weights θ with some optimizer, e.g. SGD:

     θ ← θ − η · ∇_θ ℓ(f(x + δ), y)
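
To make the steps concrete, here is a minimal PyTorch sketch of one FGSM training step applied to input embeddings. The interface (a `model` that accepts embeddings directly, the `eps`/`alpha` defaults) is an assumption for illustration, not this repository's actual API:

```python
import torch
import torch.nn.functional as F

def fgsm_train_step(model, embeddings, labels, optimizer, eps=0.1, alpha=0.125):
    """One FGSM adversarial training step on input embeddings (hypothetical API)."""
    # 1. Initialize the perturbation uniformly in (-eps, eps).
    delta = torch.empty_like(embeddings).uniform_(-eps, eps)
    delta.requires_grad_(True)

    # 2. Gradient of the loss w.r.t. delta at the perturbed input.
    loss = F.cross_entropy(model(embeddings + delta), labels)
    grad = torch.autograd.grad(loss, delta)[0]
    with torch.no_grad():
        delta += alpha * grad.sign()
        # 3. Project delta back into (-eps, eps).
        delta.clamp_(-eps, eps)

    # 4. Update the model weights on the final adversarial example.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(embeddings + delta.detach()), labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```

Since the perturbation lives in embedding space, the token ids themselves are never modified; this matches the idea of perturbing the embeddings of the inputs described above.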

2. Free

The pseudocode of "Free" is as follows:

(figure: Free pseudocode from the paper)

Free adversarial training can be regarded as repeating several FGSM-style attacks within one batch of x:

  1. Initialize δ to 0 before training starts.

  2. For each input batch x, the Free adversary performs the attack N times in succession.

  3. On every attack, it computes the gradients for the perturbation and for the model weights simultaneously, from a single backward pass:

     ∇_δ ℓ(f(x + δ), y) and ∇_θ ℓ(f(x + δ), y)

  4. Update δ:

     δ ← δ + ε · sign(∇_δ ℓ(f(x + δ), y))

     This formula is similar to FGSM's; the only difference is the coefficient of the sign function (ε instead of α).

  5. As in FGSM, project δ back into (−ε, ε) if its absolute value is too great:

     δ ← max(min(δ, ε), −ε)

  6. Update the model weights θ with some optimizer, e.g. SGD:

     θ ← θ − η · ∇_θ ℓ(f(x + δ), y)

Note: because the Free adversary attacks N times for each batch, the number of training epochs can be decreased to T/N.
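
A minimal sketch of the Free training loop under the same assumptions. Note the persistent `delta`, the single backward pass that serves both the perturbation and the weight update, and the step coefficient ε instead of α (a fixed batch size is assumed so that `delta` keeps a valid shape across batches):

```python
import torch
import torch.nn.functional as F

def train_free(model, loader, optimizer, eps=0.1, n_repeats=4, n_epochs=3):
    """'Free' adversarial training loop (hypothetical API)."""
    delta = None  # persists across minibatches, initialized to zero on first use
    for _ in range(n_epochs):  # n_epochs can be reduced to T / n_repeats
        for embeddings, labels in loader:
            if delta is None:
                delta = torch.zeros_like(embeddings, requires_grad=True)
            for _ in range(n_repeats):  # replay the same minibatch N times
                loss = F.cross_entropy(model(embeddings + delta), labels)
                optimizer.zero_grad()
                if delta.grad is not None:
                    delta.grad.zero_()
                # One backward pass yields gradients for delta and the weights.
                loss.backward()
                with torch.no_grad():
                    delta += eps * delta.grad.sign()  # coefficient eps, not alpha
                    delta.clamp_(-eps, eps)
                optimizer.step()
```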

3. PGD

PGD adversarial training updates the perturbation N times before updating the model weights:

(figure: PGD pseudocode from the paper)

The steps are as follows:

  1. For each input x, PGD initializes δ to zero.

  2. Loop N times, updating δ as below:

     δ ← δ + α · sign(∇_δ ℓ(f(x + δ), y))

  3. If δ exceeds the range (−ε, ε), project it back:

     δ ← max(min(δ, ε), −ε)

  4. After the N attack iterations, update the model weights with some optimizer:

     θ ← θ − η · ∇_θ ℓ(f(x + δ), y)
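
A corresponding sketch of one PGD training step, again under a hypothetical interface; the perturbation is refined over `n_iters` inner steps before a single weight update, which is why PGD costs roughly N extra forward/backward passes per batch:

```python
import torch
import torch.nn.functional as F

def pgd_train_step(model, embeddings, labels, optimizer,
                   eps=0.1, alpha=0.03, n_iters=3):
    """One PGD adversarial training step on input embeddings (hypothetical API)."""
    delta = torch.zeros_like(embeddings, requires_grad=True)

    # Inner loop: N projected gradient steps on delta only.
    for _ in range(n_iters):
        loss = F.cross_entropy(model(embeddings + delta), labels)
        grad = torch.autograd.grad(loss, delta)[0]
        with torch.no_grad():
            delta += alpha * grad.sign()
            delta.clamp_(-eps, eps)  # project back into (-eps, eps)

    # Outer step: update the weights at the final perturbation.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(embeddings + delta.detach()), labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```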

To implement the three adversarial training methods above, this repository includes two different approaches:

  1. Encapsulate an adversarial training class that adds perturbations to the inputs during training (used for FGSM, PGD, etc.).
  2. Create a TextCNN class that can add perturbations when an instance calls its forward function (used for Free, etc.); see the sketch after this list.
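
A minimal sketch of the second approach: a TextCNN whose forward function accepts an optional perturbation applied to the embedding output. The architecture details (kernel sizes, filter count, embedding dimension) are illustrative assumptions, not necessarily the repository's configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PerturbableTextCNN(nn.Module):
    """TextCNN that can inject a perturbation after its embedding layer."""

    def __init__(self, vocab_size, embed_dim=300, num_classes=10,
                 kernel_sizes=(3, 4, 5), num_filters=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, num_filters, k) for k in kernel_sizes
        )
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)

    def forward(self, token_ids, delta=None):
        emb = self.embedding(token_ids)        # (batch, seq_len, embed_dim)
        if delta is not None:                  # adversarial hook, e.g. for Free
            emb = emb + delta
        x = emb.transpose(1, 2)                # (batch, embed_dim, seq_len)
        # Convolve, ReLU, then global max-pool over time for each kernel size.
        pooled = [F.relu(conv(x)).max(dim=2).values for conv in self.convs]
        return self.fc(torch.cat(pooled, dim=1))
```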

Configuration

Environment

  • python=3.8.8
  • joblib==1.0.1
  • numpy==1.20.1
  • pandas==1.2.4
  • scikit_learn==1.0.2
  • torch==1.11.0
  • tqdm==4.59.0

Data

The training data comes from this GitHub site and consists of 200,000 news headlines from THUCNews. There are 10 classes (finance, realty, stocks, education, science, society, politics, sports, games, and entertainment), with 20,000 texts per class, each with a length between 20 and 30. The table below shows how the dataset is split:

                  Count
Training set    180,000
Validation set   10,000
Test set         10,000
Classes              10

The model's inputs are the characters of the texts, and the pre-trained character embeddings come from Sogou News. Click here to download them.

Result

Model    Precision   Recall   F1     Accuracy
normal   91.54       91.54    1.83   91.54
FGSM     92.05       91.98    1.84   91.97
Free     90.14       89.94    1.80   89.94
PGD      92.10       92.03    1.84   92.03

According to the metrics listed above, PGD adversarial training has the best performance, but it takes at least twice as long to train models. The performance of the model trained with FGSM is close to PGD's, yet it takes only half the time of PGD training, though still about twice the time of training a model in the ordinary way.

Reference

Eric Wong, Leslie Rice, and J. Zico Kolter. "Fast is better than free: Revisiting adversarial training." ICLR 2020. https://arxiv.org/abs/2001.03994
