Preface
Adversarial training is a way of improving how a model is trained: adversarial perturbations are added during training to make the model more robust and, in practice, improve its performance. This post briefly introduces three commonly used adversarial training methods, FGSM/FGM, PGD, and FreeLB, and provides simple implementation examples for each.
Notes
This post does not go into the theory behind adversarial training in detail; the focus is on implementation. The core idea is to strengthen a model's robustness by generating adversarial perturbations during training. The examples below apply the perturbation to the embedding layer, but where exactly you perturb can be adjusted to fit your own model.
FGSM
FGSM (Fast Gradient Sign Method) is one of the earliest adversarial training ideas. Its core idea is to compute the gradient of the loss with respect to the embeddings and use that gradient to build an adversarial perturbation. Note that the code below actually implements the closely related FGM (Fast Gradient Method) variant, which scales the raw gradient by its L2 norm instead of taking its sign.
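For intuition, here is a minimal sketch of the two perturbation rules (the function names and variables are illustrative and not part of the original post): FGSM steps along the sign of the gradient, while the FGM variant used below steps along the L2-normalized gradient.

import torch

def fgsm_perturbation(grad: torch.Tensor, epsilon: float) -> torch.Tensor:
    # FGSM: a step of size epsilon along the sign of the gradient
    return epsilon * grad.sign()

def fgm_perturbation(grad: torch.Tensor, epsilon: float) -> torch.Tensor:
    # FGM: a step of size epsilon along the L2-normalized gradient
    norm = torch.norm(grad)
    return epsilon * grad / norm if norm != 0 else torch.zeros_like(grad)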
Code example:
import torch

class FGM:
    """Fast Gradient Method: adds a one-step perturbation to the embedding weights."""

    def __init__(self, model, emb_name, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name  # substring identifying the embedding parameters
        self.backup = {}

    def attack(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                # Save the clean weights, then step along the L2-normalized gradient.
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

Training loop:

fgm = FGM(model, epsilon=1, emb_name='word_embeddings')
for batch_input, batch_label in processor:
    # Clean forward/backward pass (the model is assumed to return the loss directly).
    loss = model(batch_input, batch_label)
    loss.backward()
    # Perturb the embeddings, accumulate the adversarial gradient, then restore the weights.
    fgm.attack()
    loss_adv = model(batch_input, batch_label)
    loss_adv.backward()
    fgm.restore()
    optimizer.step()
    model.zero_grad()
PGD
PGD (Projected Gradient Descent) is an improved adversarial training method that extends FGSM/FGM. Instead of taking a single perturbation step, it builds the perturbation over several iterations: each step moves a small distance along the current gradient and then projects the accumulated perturbation back into an epsilon-ball around the original embedding.
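In symbols (the notation here is mine, not from the original post): at step t the perturbation is updated as r_{t+1} = Π_{‖r‖₂ ≤ ε}(r_t + α · g_t / ‖g_t‖₂), where g_t is the gradient at the currently perturbed embedding, α is the step size (alpha in the code), ε is the radius of the allowed perturbation (epsilon in the code), and Π is the projection implemented by the project method below.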
Code example:
class PGD:
    """Projected Gradient Descent: multi-step embedding perturbation with projection."""

    def __init__(self, model, emb_name, epsilon=1.0, alpha=0.3):
        self.model = model
        self.emb_name = emb_name
        self.epsilon = epsilon  # radius of the allowed perturbation
        self.alpha = alpha      # step size of each attack step
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                if is_first_attack:
                    # Back up the clean embedding weights only once per batch.
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, self.epsilon)

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # Project the accumulated perturbation back onto the epsilon-ball.
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]

Training loop:

pgd = PGD(model, emb_name='word_embeddings', epsilon=1.0, alpha=0.3)
K = 3
for batch_input, batch_label in processor:
    # Clean forward/backward pass, then save the clean gradients.
    loss = model(batch_input, batch_label)
    loss.backward()
    pgd.backup_grad()
    for t in range(K):
        pgd.attack(is_first_attack=(t == 0))
        if t != K - 1:
            # Intermediate steps only serve to grow the perturbation.
            model.zero_grad()
        else:
            # Last step: restore the clean gradients before accumulating the adversarial ones.
            pgd.restore_grad()
        loss_adv = model(batch_input, batch_label)
        loss_adv.backward()
    pgd.restore()
    optimizer.step()
    model.zero_grad()
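A note on the training loop above: backup_grad saves the gradients produced by the clean forward/backward pass. During the first K-1 adversarial steps the gradients are zeroed, because those steps only exist to push the perturbation further. Before the last step, restore_grad puts the clean gradients back, so the final loss_adv.backward() accumulates the adversarial gradients on top of the clean ones and the optimizer takes a single step that uses both.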
FreeLB
FreeLB (Free Large-Batch adversarial training) builds on the PGD-style multi-step attack, but instead of discarding the gradients from the intermediate steps, it accumulates the parameter gradients from every adversarial step. This effectively trains on a larger "virtual" batch of adversarial examples at little extra cost, and it was proposed to improve natural language understanding models.
Code example:
class FreeLB:
    """Accumulates gradients from adv_K adversarial steps on perturbed input embeddings."""

    def __init__(self, adv_K=3, adv_lr=0.05, adv_init_mag=1000, adv_max_norm=0.0,
                 adv_norm_type='l2', base_model='bert'):
        # adv_init_mag=1000 is kept from the original post; much smaller values (well below 1)
        # are more typical in practice.
        # base_model names the attribute holding the backbone encoder (e.g. model.bert);
        # it was used but never set in the original snippet, so it is exposed here as an argument.
        self.adv_K = adv_K
        self.adv_lr = adv_lr
        self.adv_init_mag = adv_init_mag
        self.adv_max_norm = adv_max_norm
        self.adv_norm_type = adv_norm_type
        self.base_model = base_model

    def attack(self, model, inputs, gradient_accumulation_steps=1):
        input_ids = inputs['input_ids']
        # Look up the word-embedding output for this batch; unwrap DataParallel if necessary.
        if isinstance(model, torch.nn.DataParallel):
            embeds_init = getattr(model.module, self.base_model).embeddings.word_embeddings(input_ids)
        else:
            embeds_init = getattr(model, self.base_model).embeddings.word_embeddings(input_ids)

        if self.adv_init_mag > 0:
            # Random initialization of the perturbation, masked to the real (non-padding) tokens.
            input_mask = inputs['attention_mask'].to(embeds_init)
            input_lengths = torch.sum(input_mask, 1)
            if self.adv_norm_type == 'l2':
                delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                dims = input_lengths * embeds_init.size(-1)
                mag = self.adv_init_mag / torch.sqrt(dims)
                delta = (delta * mag.view(-1, 1, 1)).detach()
            elif self.adv_norm_type == 'linf':
                delta = torch.zeros_like(embeds_init).uniform_(-self.adv_init_mag,
                                                               self.adv_init_mag) * input_mask.unsqueeze(2)
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
        else:
            # Start from a zero perturbation (branch added so delta is always defined).
            delta = torch.zeros_like(embeds_init)

        for astep in range(self.adv_K):
            delta.requires_grad_()
            inputs['inputs_embeds'] = delta + embeds_init
            inputs['input_ids'] = None

            outputs = model(**inputs)
            loss, logits = outputs[:2]
            loss = loss.mean()
            loss = loss / gradient_accumulation_steps
            loss.backward()  # gradients from every adversarial step accumulate on the parameters

            delta_grad = delta.grad.clone().detach()
            if self.adv_norm_type == 'l2':
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    # Rescale any example whose perturbation exceeds the allowed L2 norm.
                    delta_norm = torch.norm(delta.view(delta.size(0), -1).float(), p=2, dim=1).detach()
                    exceed_mask = (delta_norm > self.adv_max_norm).to(embeds_init)
                    reweights = (self.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                    delta = (delta * reweights).detach()
            elif self.adv_norm_type == 'linf':
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta = torch.clamp(delta, -self.adv_max_norm, self.adv_max_norm).detach()
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))

            # Re-encode the embeddings so the next step builds a fresh computation graph.
            if isinstance(model, torch.nn.DataParallel):
                embeds_init = getattr(model.module, self.base_model).embeddings.word_embeddings(input_ids)
            else:
                embeds_init = getattr(model, self.base_model).embeddings.word_embeddings(input_ids)

        # Return the loss of the last adversarial step for logging (added; the original returned nothing).
        return loss
Training loop:
FreeLB is a fairly flexible adversarial training method, and how it is wired into a training loop depends on your model and trainer; refer to the implementation above and the original paper. One possible way to call it is sketched below.
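The sketch assumes a HuggingFace-style model whose forward returns (loss, logits) when the batch contains 'input_ids', 'attention_mask' and 'labels'. The names train_loader and optimizer, as well as the hyperparameter values, are illustrative and not from the original post.

freelb = FreeLB(adv_K=3, adv_lr=0.05, adv_init_mag=0.02, adv_max_norm=1.0,
                adv_norm_type='l2', base_model='bert')

for batch in train_loader:
    model.zero_grad()
    # attack() runs the adv_K adversarial forward/backward passes internally, so the
    # gradients are already accumulated on the parameters when it returns.
    loss = freelb.attack(model, dict(batch))  # pass a copy, attack() mutates the dict
    optimizer.step()

Unlike FGM and PGD there is no separate clean pass here; every one of the adv_K steps contributes gradients, which is where the method gets its "free" extra adversarial batches.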
Results
Across a series of experiments and real applications, these adversarial training methods did bring a measurable performance improvement, and they are particularly effective at making the model more stable under adversarial (perturbed) inputs.