Comparing Adversarial Training Methods in NLP

Adversarial training injects noise into the training process; it acts as a regularizer on the parameters and improves the model's robustness and generalization.
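Formally, adversarial training is usually written as a min-max problem (the standard formulation from the literature; the notation below is supplied for context rather than taken from this post): the inner maximization searches for the worst-case perturbation r inside an epsilon-ball, and the outer minimization trains the parameters against it.

\min_\theta \; \mathbb{E}_{(x,y)\sim\mathcal{D}} \Big[ \max_{\lVert r \rVert \le \epsilon} L(\theta,\, x + r,\, y) \Big]

Since text input is discrete, in NLP the perturbation r is added to the word embeddings rather than to the raw tokens; that is exactly what the attack/restore helpers below do.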

Common adversarial training methods include FGSM, FGM, PGD, FreeAT, YOPO, FreeLB, SMART, and AWP.

This post provides code and experimental results for FGSM, FGM, PGD, and FreeAT.

Repository: GTyingzi/Compare_Adversial (github.com)

Adversarial Training Code

FGSM
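FGSM takes a single step along the sign of the gradient, which maximizes the loss under an L-infinity constraint (the standard formulation from the literature, stated here for context):

r_{\mathrm{adv}} = \epsilon \cdot \mathrm{sign}\big(\nabla_{x} L(\theta, x, y)\big)

The implementation below applies this perturbation to the embedding weights.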

Official implementation
import torch                       # imports used throughout the snippets below
import torch.nn.functional as F

class FGSM:
    def __init__(self, model, eps=1):
        self.model = model
        self.eps = eps
        self.backup = {}

    def attack(self, emb_name='embedding'):
        # Replace emb_name with the name of your model's embedding parameter
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                r_at = self.eps * param.grad.sign()  # one step along the gradient sign
                param.data.add_(r_at)

    def restore(self, emb_name='embedding'):
        # Replace emb_name with the name of your model's embedding parameter
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
Example
fgsm = FGSM(model=model)
for i, (trains, labels) in enumerate(train_iter):
    # Standard training step
    outputs = model(trains)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()  # backprop to get the clean gradients

    # Adversarial training step
    fgsm.attack()    # add the adversarial perturbation to the embedding
    outputs = model(trains)
    loss_adv = F.cross_entropy(outputs, labels)
    loss_adv.backward()  # backprop again, accumulating the adversarial gradients on top of the clean ones
    fgsm.restore()   # restore the embedding parameters

    # Gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()

FGM
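FGM replaces the sign with the L2-normalized gradient, so the perturbation points exactly along the gradient direction (again the standard formulation, stated for context):

r_{\mathrm{adv}} = \epsilon \cdot \frac{g}{\lVert g \rVert_2}, \qquad g = \nabla_{x} L(\theta, x, y)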

Official implementation
class FGM:
    def __init__(self, model, eps=1):
        self.model = model
        self.eps = eps
        self.backup = {}

    def attack(self, emb_name='embedding'):
        # Replace emb_name with the name of your model's embedding parameter
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.eps * param.grad / norm  # L2-normalized gradient step
                    param.data.add_(r_at)

    def restore(self, emb_name='embedding'):
        # Replace emb_name with the name of your model's embedding parameter
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
Example
fgm = FGM(model=model)
for i, (trains, labels) in enumerate(train_iter):
    # Standard training step
    outputs = model(trains)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()  # backprop to get the clean gradients

    # Adversarial training step
    fgm.attack()     # add the adversarial perturbation to the embedding
    outputs = model(trains)
    loss_adv = F.cross_entropy(outputs, labels)
    loss_adv.backward()  # backprop again, accumulating the adversarial gradients on top of the clean ones
    fgm.restore()    # restore the embedding parameters

    # Gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()

PGD
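PGD replaces the single step with K small steps of size alpha, projecting back onto the epsilon-ball around the starting point after each step. In the L2 form this code implements:

x_{t+1} = \Pi_{\lVert x - x_0 \rVert_2 \le \epsilon} \Big( x_t + \alpha \cdot \frac{g_t}{\lVert g_t \rVert_2} \Big), \qquad g_t = \nabla_{x_t} L(\theta, x_t, y)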

Official implementation
class PGD:
    def __init__(self, model, eps=1, alpha=0.3):
        self.model = model
        self.eps = eps      # radius of the L2 ball around the original embedding
        self.alpha = alpha  # step size of each attack iteration
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, emb_name='embedding', is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data)

    def restore(self, emb_name='embedding'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data):
        # Project the perturbed embedding back onto the eps-ball around the original
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > self.eps:
            r = self.eps * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]
Example
pgd = PGD(model=model)
for i, (trains, labels) in enumerate(train_iter):
    # Standard training step
    outputs = model(trains)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()  # backprop to get the clean gradients

    # Adversarial training: K attack steps
    pgd_k = 3
    pgd.backup_grad()  # save the clean gradients before the attack loop
    for _t in range(pgd_k):
        # Add the adversarial perturbation to the embedding; back up param.data on the first attack
        pgd.attack(is_first_attack=(_t == 0))
        if _t != pgd_k - 1:
            model.zero_grad()
        else:
            pgd.restore_grad()  # restore the clean gradients for the final accumulation
        outputs = model(trains)
        loss_adv = F.cross_entropy(outputs, labels)
        loss_adv.backward()  # accumulate the adversarial gradients on top of the clean ones
    pgd.restore()  # restore the embedding parameters

    # Gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()

FreeAT
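FreeAT amortizes the cost of the inner maximization by replaying each mini-batch m times: every replay reuses the perturbation carried over from the previous step (last_r_at in the code) and also produces parameter gradients, so the adversarial steps come almost "for free". In the form implemented here, each replay adds the carried-over perturbation to the embedding, projects the result back onto the epsilon-ball, and then grows the perturbation:

r_{t+1} = r_t + \epsilon \cdot \mathrm{sign}(g_t)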

Official implementation
class FreeAT:
    def __init__(self, model, eps=0.1):
        self.model = model
        self.eps = eps
        self.emb_backup = {}
        self.grad_backup = {}
        self.last_r_at = 0  # perturbation carried over from the previous step

    def attack(self, emb_name='embedding', is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                param.data.add_(self.last_r_at)  # reuse the previous perturbation
                param.data = self.project(name, param.data)
                self.last_r_at = self.last_r_at + self.eps * param.grad.sign()

    def restore(self, emb_name='embedding'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data):
        # Project the perturbed embedding back onto the eps-ball around the original
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > self.eps:
            r = self.eps * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]
Example
free_at = FreeAT(model=model)
for i, (trains, labels) in enumerate(train_iter):
    # Standard training step
    outputs = model(trains)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()  # backprop to get the clean gradients

    # Adversarial training: replay the batch m times
    m = 5
    free_at.backup_grad()  # save the clean gradients before the attack loop
    for _t in range(m):
        # Add the adversarial perturbation to the embedding; back up param.data on the first attack
        free_at.attack(is_first_attack=(_t == 0))
        if _t != m - 1:
            model.zero_grad()
        else:
            free_at.restore_grad()  # restore the clean gradients for the final accumulation
        outputs = model(trains)
        loss_adv = F.cross_entropy(outputs, labels)
        loss_adv.backward()  # accumulate the adversarial gradients on top of the clean ones
    free_at.restore()  # restore the embedding parameters

    # Gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()
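To make the call pattern concrete, here is a minimal self-contained sketch of one FGM training step, reusing the FGM class defined above. The toy TextClassifier, its sizes, and the fake batch are hypothetical stand-ins (not the TextCNN/TextRNN used in the experiments); only the attack/restore pattern matches the snippets above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TextClassifier(nn.Module):
    # Hypothetical toy model: embedding + mean pooling + linear classifier
    def __init__(self, vocab_size=100, emb_dim=16, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        return self.fc(self.embedding(x).mean(dim=1))  # mean-pool over tokens

model = TextClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
fgm = FGM(model=model)  # FGM class as defined earlier in this post

trains = torch.randint(0, 100, (8, 12))  # fake batch: 8 sentences, 12 tokens each
labels = torch.randint(0, 2, (8,))

outputs = model(trains)
loss = F.cross_entropy(outputs, labels)
loss.backward()                # clean gradients

fgm.attack()                   # perturb the embedding weights
loss_adv = F.cross_entropy(model(trains), labels)
loss_adv.backward()            # accumulate adversarial gradients
fgm.restore()                  # restore the embedding weights

optimizer.step()
model.zero_grad()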

Experimental Results

TextCNN + attack_train

baseline + attack_train | precision | recall | F1
------------------------|-----------|--------|-------
TextCNN                 | 0.9083    | 0.9078 | 0.9079
TextCNN + FGSM          | 0.9105    | 0.9103 | 0.9103
TextCNN + FGM           | 0.9110    | 0.9104 | 0.9105
TextCNN + PGD           | 0.9103    | 0.9098 | 0.9099
TextCNN + FreeAT        | 0.9104    | 0.9097 | 0.9096

TextRNN + attack_train

baseline + attack_train | precision | recall | F1
------------------------|-----------|--------|-------
TextRNN                 | 0.9046    | 0.9034 | 0.9038
TextRNN + FGSM          | 0.9068    | 0.9055 | 0.9058
TextRNN + FGM           | 0.9160    | 0.9161 | 0.9160
TextRNN + PGD           | 0.9144    | 0.9142 | 0.9140
TextRNN + FreeAT        | 0.9064    | 0.9062 | 0.9059

References:

attack_train/Attack-Train-Compare-Pytorch at main · tanshoudong/attack_train (github.com)

对抗训练fgm、fgsm和pgd原理和源码分析 - 谈笑风生…的博客 (CSDN)

一文搞懂NLP中的对抗训练FGSM/FGM/PGD/FreeAT/YOPO/FreeLB/SMART - 知乎 (zhihu.com)

对抗学习总结:FGSM->FGM->PGD->FreeAT, YOPO->FreeLb->SMART->LookAhead->VAT - zhurui_xiaozhuzaizai的博客 (CSDN)

lonePatient/TorchBlocks: A PyTorch-based toolkit for natural language processing (github.com)
