这是一种使用GAN来生成对抗样本的模型
代码:
首先来看一个训练过程
代码中首先训练的是D
首先用generator生成干扰项 perturbation,然后与原图相加形成对抗样本 adv_images
当然训练一个D的loss分为了两部分,loss_D_real旨在拉近吃正样本之后的输出与1的距离
loss_D_fake旨在拉近吃负样本之后与0的距离,这里的负样本就是对抗样本,输入的时候不要忘了detach掉
# optimize D
# x are the input images
for i in range(1):
perturbation = self.netG(x) # torch.Size([128, 1, 28, 28])
# add a clipping trick
adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
adv_images = torch.clamp(adv_images, self.box_min, self.box_max)
self.optimizer_D.zero_grad()
pred_real = self.netDisc(x)
loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
loss_D_real.backward()
pred_fake = self.netDisc(adv_images.detach())
loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))
loss_D_fake.backward()
loss_D_GAN = loss_D_fake + loss_D_real
self.optimizer_D.step()
训练G的过程就有些复杂了
首先要G的Gan损失的训练目标是让自己生成的对抗样本,在D看起来和正样本1相近
下方的retain_graph = True的意思是保留当前方向传播的计算图,可以做梯度累加
可以参见这两篇博客https://www.cnblogs.com/picassooo/p/13748618.html
https://www.cnblogs.com/picassooo/p/13818952.html
self.optimizer_G.zero_grad()
# cal G's loss in GAN
pred_fake = self.netDisc(adv_images)
loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))
loss_G_fake.backward(retain_graph=True)
接下来就是限制扰动大小的损失
这里设计的是一个batch之中所有图片的矩阵二范数都不能太大
# calculate perturbation norm
C = 0.1
loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))
接下来就是样本对抗损失
onehot_labels这里的实现是比较优雅的,总体功能是根据手写数字的类别转换为onehot编码的格式,torch.eye的功能就是得到onehot编码,然后使用lables变量中对应的类别把他提取出来
real的功能,按照我粗浅的理解,是得到网络针对一个batch中所有对抗样本预测正确的概率。other的功能,是得到了网络针对一个batch中的所有对抗样本预测为错误的类别中,可能性最大的概率。
那个torch.max(real-other,0)的功能,按照我粗浅的理解,首先看real-ohter的部分,因为损失函数都是梯度下降的,最小化这个损失函数,相当于训练模型让real更小,other更大,犯错的概率越大。之所以要与0相max,也许是小于0的时候,other已经大于real了,然后没必要训练这个部分了?
最后一个损失函数我可能理解的不正确,还是要看一下那个C&W模型是怎么设计的
# cal adv loss
logits_model = self.model(adv_images)
probs_model = F.softmax(logits_model, dim=1)
onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels] # torch.Size([128, 10])
# C&W loss function
real = torch.sum(onehot_labels * probs_model, dim=1) # [128]
other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
zeros = torch.zeros_like(other)
loss_adv = torch.max(real -
other, zeros)
loss_adv = torch.sum(loss_adv)
接下来就是把这两个loss乘以一个超参权重,然后backward就好了
————————————————
版权声明:本文为CSDN博主「Blue_Whale2020」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/Blue_Whale2020/article/details/124051810