Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新增attack的方法和用例 #72

Merged
merged 7 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ark_nlp/factory/loss_function/global_pointer_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def forward(self, logits, target):
logits: [N, C, L, L]
"""
bh = logits.shape[0] * logits.shape[1]
target = torch.reshape(target.to_dense(), (bh, -1))
target = torch.reshape(target, (bh, -1))
logits = torch.reshape(logits, (bh, -1))
return torch.mean(GlobalPointerCrossEntropy.multilabel_categorical_crossentropy(target, logits))
4 changes: 4 additions & 0 deletions ark_nlp/factory/utils/attack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ark_nlp.factory.utils.attack.fgm import FGMAttackMixin
from ark_nlp.factory.utils.attack.pgd import PGDAttackMixin
from ark_nlp.factory.utils.attack.freelb import FreeLBAttackMixin
from ark_nlp.factory.utils.attack.awp import AWPAttackMixin
159 changes: 159 additions & 0 deletions ark_nlp/factory/utils/attack/awp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import torch
import warnings
from torch.utils.data import DataLoader

from ark_nlp.factory.optimizer import get_optimizer


class AWP(object):
    """
    Adversarial Weight Perturbation (AWP) attack helper.

    Perturbs the trainable parameters of ``module`` whose name contains a
    given substring along the direction of their current gradient, clamps
    every perturbed value to a relative band around the value that was saved
    when the parameter was first attacked, and can later restore both the
    parameters and a saved copy of the gradients.

    Args:
        module (:obj:`torch.nn.Module`): model whose parameters are perturbed

    Reference:
        [1] [Adversarial weight perturbation helps robust generalization](https://arxiv.org/abs/2004.05884)
    """
    def __init__(self, module):
        self.module = module
        # name -> clone of the parameter values taken before the perturbation
        self.param_backup = {}
        # name -> (lower bound, upper bound) tensors used to clamp the perturbation
        self.param_backup_eps = {}
        # name -> clone of the gradient taken by ``backup_grad``
        self.grad_backup = {}

    def attack(
        self,
        epsilon=0.001,
        alpha=1.0,
        emb_name='weight',
        is_first_attack=False
    ):
        """
        Add one adversarial perturbation step to the matching parameters.

        Only parameters that require grad, currently hold a gradient and
        whose name contains ``emb_name`` are touched.

        Args:
            epsilon (float): half-width of the clamp band, relative to the
                absolute value of each saved parameter entry
            alpha (float): step size; ``0`` turns the call into a no-op
            emb_name (str): substring a parameter name must contain
            is_first_attack (bool): when True the current parameter values
                are (re-)saved as the clean reference before perturbing
        """
        if alpha == 0:
            return

        e = 1e-6
        for name, param in self.module.named_parameters():
            if not (param.requires_grad and param.grad is not None and emb_name in name):
                continue

            # Save the clean values and the clamp band. Besides the explicit
            # first attack this is also done lazily for any parameter that has
            # no band yet; otherwise calling ``attack()`` with its default
            # arguments (or a parameter that only receives a gradient on a
            # later step) fails with a KeyError in the clamp below.
            if is_first_attack or name not in self.param_backup_eps:
                clean = param.data.clone()
                grad_eps = epsilon * param.abs().detach()
                self.param_backup[name] = clean
                self.param_backup_eps[name] = (clean - grad_eps, clean + grad_eps)

            norm1 = torch.norm(param.grad)
            norm2 = torch.norm(param.data.detach())
            # A zero or NaN gradient norm gives no usable direction: skip.
            if norm1 != 0 and not torch.isnan(norm1):
                # Step along the gradient, scaled by the weight norm.
                r_at = alpha * param.grad / (norm1 + e) * (norm2 + e)
                param.data.add_(r_at)
                lower, upper = self.param_backup_eps[name]
                param.data = torch.min(torch.max(param.data, lower), upper)

    def restore(self):
        """Put back every saved parameter value and drop the saved state."""
        for name, param in self.module.named_parameters():
            if name in self.param_backup:
                param.data = self.param_backup[name]
        self.param_backup = {}
        self.param_backup_eps = {}

    def backup_grad(self):
        """Clone the gradient of every trainable parameter that has one."""
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        """Re-install the gradients saved by ``backup_grad`` and drop them."""
        for name, param in self.module.named_parameters():
            if name in self.grad_backup:
                param.grad = self.grad_backup[name]
        self.grad_backup = {}


class AWPAttackMixin(object):
    """
    Task mixin that adds AWP adversarial training to the training hooks.

    The host class is expected to provide the attributes and hooks used
    below (``module``, ``optimizer``, ``class_num``, ``n_gpu``,
    ``_train_collate_fn``, ``_get_train_loss``, ``_on_train_begin_record``
    and ``_on_backward_record``).
    """

    def _on_train_begin(
        self,
        train_data,
        validation_data,
        batch_size,
        lr,
        params,
        shuffle,
        num_workers=0,
        train_to_device_cols=None,
        **kwargs
    ):
        """
        Prepare label maps, data loader, optimizer and the AWP attacker.

        Args:
            train_data: training dataset handed to the ``DataLoader``
            validation_data: not used by this hook
            batch_size (int): batch size of the training ``DataLoader``
            lr: learning rate forwarded to ``get_optimizer``
            params: parameter spec forwarded to ``get_optimizer``
            shuffle (bool): whether the training ``DataLoader`` shuffles
            num_workers (int): worker count of the training ``DataLoader``
            train_to_device_cols: overrides ``train_data.to_device_cols``
                when given
            **kwargs: forwarded to ``_on_train_begin_record``

        Returns:
            The training ``DataLoader``.
        """
        if hasattr(train_data, 'id2cat'):
            self.id2cat = train_data.id2cat
            self.cat2id = {v_: k_ for k_, v_ in train_data.id2cat.items()}

        # class_num may be given at construction time; when it was not,
        # fall back to the value carried by the training set.
        if self.class_num is None:
            if hasattr(train_data, 'class_num'):
                self.class_num = train_data.class_num
            else:
                warnings.warn("The class_num is None.")

        if train_to_device_cols is None:
            self.train_to_device_cols = train_data.to_device_cols
        else:
            self.train_to_device_cols = train_to_device_cols

        train_generator = DataLoader(
            train_data,
            batch_size=batch_size,
            # Honour the caller's choice; this used to be hard-coded to True,
            # which silently ignored the ``shuffle`` argument.
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self._train_collate_fn
        )
        # NOTE: the attribute name keeps its historical spelling ("lenth")
        # because code outside this block may read it.
        self.train_generator_lenth = len(train_generator)

        self.optimizer = get_optimizer(self.optimizer, self.module, lr, params)
        self.optimizer.zero_grad()

        self.module.train()

        self.awp = AWP(self.module)
        # Number of attack / forward / backward rounds per batch.
        self.awp_k = 3

        self._on_train_begin_record(**kwargs)

        return train_generator

    def _on_backward(
        self,
        inputs,
        outputs,
        logits,
        loss,
        gradient_accumulation_steps=1,
        **kwargs
    ):
        """
        Back-propagate the clean loss, then ``awp_k`` rounds of AWP.

        The clean gradients are saved first. Every round perturbs the
        weights and back-propagates the loss of the perturbed model; the
        gradients are cleared before each round except the last one, where
        the saved clean gradients are re-installed so that the final
        adversarial gradient is accumulated on top of them. The weights are
        restored at the end.

        Args:
            inputs (dict): keyword arguments for ``self.module``
            outputs: not used by this hook
            logits: not used by this hook (recomputed on the perturbed model)
            loss: clean training loss
            gradient_accumulation_steps (int): accumulation factor the
                losses are divided by
            **kwargs: forwarded to ``_get_train_loss`` and
                ``_on_backward_record``

        Returns:
            The normalised clean loss.
        """
        # With more than one GPU the loss is averaged first.
        if self.n_gpu > 1:
            loss = loss.mean()
        # With gradient accumulation the loss is divided by the number of
        # accumulation steps.
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()

        self.awp.backup_grad()
        last_round = self.awp_k - 1
        for t in range(self.awp_k):
            self.awp.attack(is_first_attack=(t == 0))
            if t != last_round:
                self.optimizer.zero_grad()
            else:
                self.awp.restore_grad()
            logits = self.module(**inputs)
            _, attck_loss = self._get_train_loss(inputs, logits, **kwargs)
            # Normalise the adversarial loss exactly like the clean loss;
            # otherwise its gradient is weighted gradient_accumulation_steps
            # times heavier than the clean one.
            if self.n_gpu > 1:
                attck_loss = attck_loss.mean()
            if gradient_accumulation_steps > 1:
                attck_loss = attck_loss / gradient_accumulation_steps
            attck_loss.backward()
        self.awp.restore()

        self._on_backward_record(loss, **kwargs)

        return loss
140 changes: 140 additions & 0 deletions ark_nlp/factory/utils/attack/fgm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import torch
import warnings
from torch.utils.data import DataLoader

from ark_nlp.factory.optimizer import get_optimizer


class FGM(object):
    """
    Fast Gradient Method (FGM) attack helper.

    Adds a gradient-direction perturbation to the trainable parameters whose
    name contains a given substring and restores them afterwards.

    Args:
        module (:obj:`torch.nn.Module`): model whose parameters are perturbed

    Examples::

        >>> # set-up
        >>> fgm = FGM(module)
        >>> for batch_input, batch_label in data:
        >>>     # regular training step
        >>>     loss = module(batch_input, batch_label)
        >>>     loss.backward()  # back-propagate to get the regular grad
        >>>     # adversarial training step
        >>>     fgm.attack()  # add the adversarial perturbation to the embedding
        >>>     loss_adv = module(batch_input, batch_label)
        >>>     loss_adv.backward()  # accumulate the adversarial grad on top of the regular grad
        >>>     fgm.restore()  # restore the embedding parameters
        >>>     # gradient descent, update the parameters
        >>>     optimizer.step()
        >>>     optimizer.zero_grad()

    Reference:
        [1] https://zhuanlan.zhihu.com/p/91269728
    """
    def __init__(self, module):
        self.module = module
        # name -> clone of the parameter values taken before the perturbation
        self.backup = {}

    def attack(
        self,
        epsilon=1.,
        emb_name='word_embeddings'
    ):
        """
        Perturb every trainable parameter whose name contains ``emb_name``.

        Each matching parameter is backed up first and then moved by
        ``epsilon * grad / ||grad||``. Parameters without a gradient, or
        with a zero / NaN gradient norm, are backed up but left unchanged.

        Args:
            epsilon (float): size of the perturbation
            emb_name (str): substring a parameter name must contain
        """
        for name, param in self.module.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Always back up, so that ``restore`` finds every matching name.
            self.backup[name] = param.data.clone()
            # A matching parameter that took no part in the backward pass has
            # no gradient; torch.norm(None) would raise, so skip it.
            if param.grad is None:
                continue
            norm = torch.norm(param.grad)
            if norm != 0 and not torch.isnan(norm):
                r_at = epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(
        self,
        emb_name='word_embeddings'
    ):
        """
        Put back the values saved by ``attack`` and drop the backup.

        Args:
            emb_name (str): substring a parameter name must contain; every
                matching trainable parameter must have been backed up

        Raises:
            AssertionError: if a matching parameter has no backup
        """
        for name, param in self.module.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}


class FGMAttackMixin(object):
    """
    Task mixin that adds FGM adversarial training to the training hooks.

    The host class is expected to provide the attributes and hooks used
    below (``module``, ``optimizer``, ``class_num``, ``n_gpu``,
    ``_train_collate_fn``, ``_get_train_loss``, ``_on_train_begin_record``
    and ``_on_backward_record``).
    """

    def _on_train_begin(
        self,
        train_data,
        validation_data,
        batch_size,
        lr,
        params,
        shuffle,
        num_workers=0,
        train_to_device_cols=None,
        **kwargs
    ):
        """
        Prepare label maps, data loader, optimizer and the FGM attacker.

        Args:
            train_data: training dataset handed to the ``DataLoader``
            validation_data: not used by this hook
            batch_size (int): batch size of the training ``DataLoader``
            lr: learning rate forwarded to ``get_optimizer``
            params: parameter spec forwarded to ``get_optimizer``
            shuffle (bool): whether the training ``DataLoader`` shuffles
            num_workers (int): worker count of the training ``DataLoader``
            train_to_device_cols: overrides ``train_data.to_device_cols``
                when given
            **kwargs: forwarded to ``_on_train_begin_record``

        Returns:
            The training ``DataLoader``.
        """
        if hasattr(train_data, 'id2cat'):
            self.id2cat = train_data.id2cat
            self.cat2id = {v_: k_ for k_, v_ in train_data.id2cat.items()}

        # class_num may be given at construction time; when it was not,
        # fall back to the value carried by the training set.
        if self.class_num is None:
            if hasattr(train_data, 'class_num'):
                self.class_num = train_data.class_num
            else:
                warnings.warn("The class_num is None.")

        if train_to_device_cols is None:
            self.train_to_device_cols = train_data.to_device_cols
        else:
            self.train_to_device_cols = train_to_device_cols

        train_generator = DataLoader(
            train_data,
            batch_size=batch_size,
            # Honour the caller's choice; this used to be hard-coded to True,
            # which silently ignored the ``shuffle`` argument.
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self._train_collate_fn
        )
        # NOTE: the attribute name keeps its historical spelling ("lenth")
        # because code outside this block may read it.
        self.train_generator_lenth = len(train_generator)

        self.optimizer = get_optimizer(self.optimizer, self.module, lr, params)
        self.optimizer.zero_grad()

        self.module.train()

        self.fgm = FGM(self.module)

        self._on_train_begin_record(**kwargs)

        return train_generator

    def _on_backward(
        self,
        inputs,
        outputs,
        logits,
        loss,
        gradient_accumulation_steps=1,
        **kwargs
    ):
        """
        Back-propagate the clean loss, then one FGM adversarial pass.

        After the clean backward pass the attacker perturbs the model, the
        loss of the perturbed model is back-propagated so that its gradient
        is accumulated on top of the clean one, and the attacker restores
        the model.

        Args:
            inputs (dict): keyword arguments for ``self.module``
            outputs: not used by this hook
            logits: not used by this hook (recomputed on the perturbed model)
            loss: clean training loss
            gradient_accumulation_steps (int): accumulation factor the
                losses are divided by
            **kwargs: forwarded to ``_get_train_loss`` and
                ``_on_backward_record``

        Returns:
            The normalised clean loss.
        """
        # With more than one GPU the loss is averaged first.
        if self.n_gpu > 1:
            loss = loss.mean()
        # With gradient accumulation the loss is divided by the number of
        # accumulation steps.
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()

        self.fgm.attack()
        logits = self.module(**inputs)
        _, attck_loss = self._get_train_loss(inputs, logits, **kwargs)
        # Normalise the adversarial loss exactly like the clean loss;
        # otherwise its gradient is weighted gradient_accumulation_steps
        # times heavier than the clean one.
        if self.n_gpu > 1:
            attck_loss = attck_loss.mean()
        if gradient_accumulation_steps > 1:
            attck_loss = attck_loss / gradient_accumulation_steps
        attck_loss.backward()
        self.fgm.restore()

        self._on_backward_record(loss, **kwargs)

        return loss
Loading