A repository that implements three adversarial training methods, FGSM, "Free" and PGD, published in the paper "Fast is better than free: Revisiting adversarial training" by Eric Wong, Leslie Rice, and Zico Kolter. These methods are used to train textCNN models, whose performance is then compared with an original textCNN model trained in the ordinary way.
By introducing perturbations into the embeddings of inputs, adversarial training regularizes the model's parameters to improve its robustness and generalization. It is assumed that small perturbations of the inputs should not change the distribution of the outputs.
The paper presents three different adversaries, as follows:
FGSM, short for Fast Gradient Sign Method, is summarized below:
The pseudo code above shows how FGSM performs an adversarial attack for each input x:
- Initialize the perturbation δ from a uniform distribution between -ε and ε.
- Add the perturbation to the input x, calculate the gradient of the loss with respect to δ, and update δ as below: δ ← clamp(δ + α·sign(∇δ ℓ(f(x+δ), y)), -ε, ε)
- Update the model weights θ with some optimizer, e.g. SGD: θ ← θ − τ·∇θ ℓ(f(x+δ), y)
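As a minimal sketch (not the repository's actual code), the FGSM step above might look like the following in PyTorch, assuming `model` consumes embedded inputs directly and `epsilon`, `alpha` are hyperparameters:

```python
import torch
import torch.nn as nn

def fgsm_train_step(model, x_embed, y, optimizer, criterion,
                    epsilon=0.1, alpha=0.12):
    """One FGSM training step on a batch of embedded inputs (illustrative sketch)."""
    # 1) Initialize the perturbation uniformly in [-epsilon, epsilon]
    delta = torch.zeros_like(x_embed).uniform_(-epsilon, epsilon)
    delta.requires_grad_(True)

    # 2) One forward/backward pass to obtain the gradient w.r.t. delta,
    #    then take a signed step and project back into the epsilon-ball
    loss = criterion(model(x_embed + delta), y)
    loss.backward()
    delta = (delta + alpha * delta.grad.sign()).clamp(-epsilon, epsilon).detach()

    # 3) Update the model weights on the adversarial example
    optimizer.zero_grad()
    loss = criterion(model(x_embed + delta), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```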
The pseudo code of "Free" is as follows:
Free adversarial training can be regarded as repeating several FGSM attacks on one batch of x:
- For each minibatch x, the Free adversary performs the FGSM adversarial attack N times in a row.
- In every FGSM attack, it computes the gradients for the perturbation and the model weights simultaneously, in a single backward pass: δ ← clamp(δ + ε·sign(∇δ ℓ(f(x+δ), y)), -ε, ε)
This formula is similar to FGSM's; the only difference is the coefficient of the sign function (ε instead of α).
- Update the model weights θ with some optimizer, e.g. SGD: θ ← θ − τ·∇θ ℓ(f(x+δ), y)
Note: because the Free adversary attacks N times for each batch, the number of training epochs can be decreased to T/N.
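The Free steps above can be sketched as follows (an illustrative sketch, not the repository's actual code; `delta` is kept across minibatches, and the weight update reuses the gradients from the same backward pass):

```python
import torch
import torch.nn as nn

def free_train_step(model, x_embed, y, optimizer, criterion, delta,
                    epsilon=0.1, n_repeats=4):
    """N "free" replays of one minibatch; delta persists across batches (sketch)."""
    for _ in range(n_repeats):
        delta.requires_grad_(True)
        loss = criterion(model(x_embed + delta), y)
        optimizer.zero_grad()
        loss.backward()  # one backward pass: gradients for delta AND the weights
        # Perturbation step: the coefficient of sign() is epsilon (vs. alpha in FGSM)
        delta = (delta + epsilon * delta.grad.sign()).clamp(-epsilon, epsilon).detach()
        optimizer.step()
    return loss.item(), delta
```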
PGD adversarial training updates the perturbation N times before updating the model weights:
The steps are as follows:
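A minimal sketch of a PGD training step (illustrative only; the inner loop refines the perturbation with `torch.autograd.grad` so the model weights are untouched until the final step):

```python
import torch
import torch.nn as nn

def pgd_train_step(model, x_embed, y, optimizer, criterion,
                   epsilon=0.1, alpha=0.02, n_steps=5):
    """PGD: refine the perturbation n_steps times, then one weight update (sketch)."""
    delta = torch.zeros_like(x_embed).uniform_(-epsilon, epsilon)
    # Inner loop: only the perturbation is updated
    for _ in range(n_steps):
        delta.requires_grad_(True)
        loss = criterion(model(x_embed + delta), y)
        grad = torch.autograd.grad(loss, delta)[0]
        delta = (delta + alpha * grad.sign()).clamp(-epsilon, epsilon).detach()
    # Outer step: update the model weights on the final adversarial example
    optimizer.zero_grad()
    loss = criterion(model(x_embed + delta), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```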
To implement the three adversarial training methods above, this repository includes two different implementations:
- encapsulate an adversarial training class that adds perturbations to the inputs during training, for methods such as FGSM, PGD, etc.
- create a textCNN class that is able to add perturbations when an instance calls the forward function, for methods such as Free, etc.
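The first approach could be sketched roughly as below (the class name `FGSMPerturb` and the `embed_name` parameter are illustrative, not the repository's actual API): perturb the embedding weights along the sign of their gradient for an extra adversarial backward pass, then restore them before the optimizer step.

```python
import torch
import torch.nn as nn

class FGSMPerturb:
    """Illustrative helper: perturb a model's embedding weights, then restore them."""

    def __init__(self, model, embed_name="embedding", epsilon=0.1):
        self.model = model
        self.embed_name = embed_name  # substring identifying the embedding parameter
        self.epsilon = epsilon
        self.backup = {}

    def attack(self):
        # Call after loss.backward(): step embedding weights along sign of gradient
        for name, param in self.model.named_parameters():
            if self.embed_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                param.data.add_(self.epsilon * param.grad.sign())

    def restore(self):
        # Undo the perturbation before the optimizer updates the clean weights
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
```

A typical usage pattern would be: `loss.backward()`, then `attack()`, a second backward pass on the perturbed embeddings to accumulate adversarial gradients, `restore()`, and finally `optimizer.step()`.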
- python=3.8.8
- joblib==1.0.1
- numpy==1.20.1
- pandas==1.2.4
- scikit_learn==1.0.2
- torch==1.11.0
- tqdm==4.59.0
The training data comes from this github site and consists of two hundred thousand news headlines from THUCNews. There are 10 classes (finance, realty, stocks, education, science, society, politics, sports, game and entertainment), with twenty thousand texts of length between 20 and 30 characters per class. The table below shows how the data set is divided:
| | Count |
|---|---|
| Training set | 180,000 |
| Validation set | 10,000 |
| Test set | 10,000 |
| Number of classes | 10 |
The model inputs are the characters of the texts, and the pre-trained character embeddings come from Sogou News. Click here to download them.
| | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|
| normal | 91.54 | 91.54 | 1.83 | 91.54 |
| FGSM | 92.05 | 91.98 | 1.84 | 91.97 |
| Free | 90.14 | 89.94 | 1.80 | 89.94 |
| PGD | 92.10 | 92.03 | 1.84 | 92.03 |
According to the metrics listed above, PGD adversarial training has the best performance, but it takes at least twice as long to train a model. The model trained with FGSM performs nearly as well as the PGD one while requiring only about half of PGD's training time, though that is still roughly twice the time of training a model in the ordinary way.