BERT Adversarial Training: Implementation Code
Published: 2019-03-23


Introduction

Adversarial training is a way of improving how a model is trained: adversarial perturbations are added during training to make the model more robust and, often, to improve its performance. This post briefly introduces three commonly used adversarial training methods, FGSM, PGD, and FreeLB, and provides simple implementation code for each.

Notes

This post does not go into the theory of adversarial training in detail; the focus is on implementation. The core idea is to strengthen the model's robustness by generating adversarial perturbations during training. The perturbations here are applied to the embedding layer, but where exactly to perturb can be adjusted to fit your own model.
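As a small illustrative sketch (an assumption of this write-up, not part of the original post): for a HuggingFace-style BERT model, you can list the parameter names to find the substring to pass as emb_name in the classes below, which for BERT is typically 'word_embeddings'. The model name 'bert-base-chinese' is only an example.

from transformers import BertForSequenceClassification

# Hypothetical example: inspect parameter names to find the embedding weights to perturb.
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
for name, param in model.named_parameters():
    if 'embeddings' in name:
        print(name)  # e.g. bert.embeddings.word_embeddings.weight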

FGSM

FGSM (Fast Gradient Sign Method) is one of the earliest adversarial training ideas. Its core idea is to compute the gradient of the loss with respect to the embedding parameters and use it to build an adversarial perturbation. (The code below actually implements the FGM variant, which scales the gradient by its L2 norm rather than taking its sign.)
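For reference (this notation is added in this write-up, not taken from the original post), the two perturbations can be written as:

$$ r_{\mathrm{FGSM}} = \epsilon \cdot \mathrm{sign}\big(\nabla_x L(x, y)\big), \qquad r_{\mathrm{FGM}} = \epsilon \cdot \frac{\nabla_x L(x, y)}{\lVert \nabla_x L(x, y) \rVert_2} $$

where x is the embedding being perturbed, y the label, and L the training loss.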

Code example:

import torch

class FGM:
    """Fast Gradient Method: perturb the embedding weights along the (normalized) gradient."""
    def __init__(self, model, emb_name, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name  # substring of the embedding parameter name, e.g. 'word_embeddings'
        self.backup = {}

    def attack(self):
        # Back up the embedding weights, then add the perturbation r = epsilon * g / ||g||
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self):
        # Restore the original embedding weights after the adversarial step
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

Training loop:

fgm = FGM(model, epsilon=1, emb_name='word_embeddings')
for batch_input, batch_label in processor:      # processor yields (input, label) batches
    loss = model(batch_input, batch_label)
    loss.backward()                             # gradients on the clean input
    fgm.attack()                                # perturb the embeddings using those gradients
    loss_adv = model(batch_input, batch_label)
    loss_adv.backward()                         # accumulate gradients on the adversarial input
    fgm.restore()                               # remove the perturbation
    optimizer.step()
    model.zero_grad()

PGD

PGD (Projected Gradient Descent) is an improved adversarial training method that extends FGSM. It builds the adversarial perturbation gradually over several iterations, taking a small gradient step at each iteration and projecting the accumulated perturbation back into the allowed epsilon-ball.
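For reference (again, notation added in this write-up), each step updates the perturbation r as:

$$ r_{t+1} = \Pi_{\lVert r \rVert_2 \le \epsilon}\!\left( r_t + \alpha \cdot \frac{\nabla_x L(x + r_t, y)}{\lVert \nabla_x L(x + r_t, y) \rVert_2} \right) $$

where alpha is the per-step size and the projection keeps the total perturbation within radius epsilon.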

Code example:

import torch

class PGD:
    """Projected Gradient Descent: multi-step perturbation with projection onto an epsilon-ball."""
    def __init__(self, model, emb_name, epsilon=1.0, alpha=0.3):
        self.model = model
        self.emb_name = emb_name  # substring of the embedding parameter name
        self.epsilon = epsilon    # radius of the perturbation ball
        self.alpha = alpha        # step size of each attack step
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, self.epsilon)

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # Project the accumulated perturbation back onto the epsilon-ball
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]

Training loop:

pgd = PGD(model, emb_name='word_embeddings', epsilon=1.0, alpha=0.3)
K = 3  # number of attack steps
for batch_input, batch_label in processor:
    loss = model(batch_input, batch_label)
    loss.backward()                  # gradients on the clean input
    pgd.backup_grad()                # save them; the inner steps will overwrite param.grad
    for t in range(K):
        pgd.attack(is_first_attack=(t == 0))
        if t != K - 1:
            model.zero_grad()        # intermediate steps only need the attack gradient
        else:
            pgd.restore_grad()       # last step: restore the clean gradients and accumulate
        loss_adv = model(batch_input, batch_label)
        loss_adv.backward()
    pgd.restore()                    # remove the perturbation from the embeddings
    optimizer.step()
    model.zero_grad()

FreeLB

FreeLB (Free Large-Batch adversarial training, from the paper "FreeLB: Enhanced Adversarial Training for Natural Language Understanding") adds a perturbation to the word embeddings and runs several inner ascent steps on it, accumulating the parameter gradients from every inner step. The effect is similar to training on a larger "virtual" batch of adversarial samples, which improves generalization on NLU tasks.

Code example:

import torch

class FreeLB:
    """FreeLB: run adv_K inner ascent steps on an embedding perturbation and accumulate their gradients."""
    def __init__(self, base_model='bert', adv_K=3, adv_lr=0.05, adv_init_mag=1000,
                 adv_max_norm=0.0, adv_norm_type='l2'):
        self.base_model = base_model      # attribute name of the backbone, e.g. model.bert
        self.adv_K = adv_K                # number of inner ascent steps
        self.adv_lr = adv_lr              # step size of each ascent step
        self.adv_init_mag = adv_init_mag  # magnitude of the random initial perturbation
        self.adv_max_norm = adv_max_norm  # maximum norm of the perturbation (0 = unbounded)
        self.adv_norm_type = adv_norm_type

    def attack(self, model, inputs, gradient_accumulation_steps=1):
        input_ids = inputs['input_ids']
        # Look up the word-embedding output for the current batch
        if isinstance(model, torch.nn.DataParallel):
            embeds_init = getattr(model.module, self.base_model).embeddings.word_embeddings(input_ids)
        else:
            embeds_init = getattr(model, self.base_model).embeddings.word_embeddings(input_ids)
        if self.adv_init_mag > 0:
            # Random initialization of the perturbation, masked to real tokens only
            input_mask = inputs['attention_mask'].to(embeds_init)
            input_lengths = torch.sum(input_mask, 1)
            if self.adv_norm_type == 'l2':
                delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                dims = input_lengths * embeds_init.size(-1)
                mag = self.adv_init_mag / torch.sqrt(dims)
                delta = (delta * mag.view(-1, 1, 1)).detach()
            elif self.adv_norm_type == 'linf':
                delta = torch.zeros_like(embeds_init).uniform_(-self.adv_init_mag, self.adv_init_mag) * input_mask.unsqueeze(2)
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
        else:
            delta = torch.zeros_like(embeds_init)
        for astep in range(self.adv_K):
            delta.requires_grad_()
            # Feed the perturbed embeddings instead of input_ids
            inputs['inputs_embeds'] = delta + embeds_init
            inputs['input_ids'] = None
            outputs = model(**inputs)
            loss, logits = outputs[:2]
            loss = loss.mean()
            loss = loss / gradient_accumulation_steps
            loss.backward()  # gradients from every inner step accumulate in the model parameters
            delta_grad = delta.grad.clone().detach()
            if self.adv_norm_type == 'l2':
                # Ascent step on delta, then clip its L2 norm to adv_max_norm
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta_norm = torch.norm(delta.view(delta.size(0), -1).float(), p=2, dim=1).detach()
                    exceed_mask = (delta_norm > self.adv_max_norm).to(embeds_init)
                    reweights = (self.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                    delta = (delta * reweights).detach()
            elif self.adv_norm_type == 'linf':
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta = torch.clamp(delta, -self.adv_max_norm, self.adv_max_norm).detach()
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
        return loss

(Compared with the original snippet, a base_model argument, a zero-initialization branch for adv_init_mag == 0, and a return value have been added so the class is usable as written.)

Training loop:

FreeLB is a fairly flexible adversarial training method; exactly how to call it depends on your training setup, so refer to the implementation code and its documentation. A rough sketch of one possible calling pattern is given below.
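The following is only a hedged sketch, not the original author's training loop. It assumes a HuggingFace-style model whose forward pass returns (loss, logits) when labels are passed, a backbone reachable as model.bert, and an existing dataloader and optimizer; the hyperparameter values (adv_lr, adv_init_mag) are illustrative assumptions.

freelb = FreeLB(base_model='bert', adv_K=3, adv_lr=1e-2, adv_init_mag=2e-2, adv_max_norm=0.0)
for batch in dataloader:
    model.train()
    # Rebuild the inputs dict each step: attack() replaces input_ids with inputs_embeds in place.
    inputs = {
        'input_ids': batch['input_ids'],
        'attention_mask': batch['attention_mask'],
        'labels': batch['labels'],
    }
    # attack() runs adv_K forward/backward passes and accumulates gradients internally,
    # so no extra loss.backward() is needed here.
    loss = freelb.attack(model, inputs)
    optimizer.step()
    model.zero_grad()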

Results

Across a series of experiments and real applications, these adversarial training methods do bring a measurable performance gain, and the improvement in robustness to adversarial perturbations is particularly noticeable.

Caveats

  • Do not overlook gradient computation on the embedding layer (the attacks rely on param.grad being populated for the embedding parameters).
  • The effectiveness of adversarial training can vary from task to task.
  • The choice of epsilon and alpha directly affects training results.

References

  • Implementation code for FGSM, PGD, and FreeLB can be found in the corresponding open-source projects.
  • For the theory behind adversarial training, see the relevant papers and technical documentation.
  • Virtual Adversarial Training (VAT) is a further adversarial training method; see the related research listed in the references.

