概述
对抗生成网络包含两个模型,一个是生成模型(generative model),一个是判别模型(discriminative model)。生成模型的任务是生成看起来自然真实的、和原始数据相似的实例。判别模型的任务是判断给定的实例看起来是自然真实的还是人为伪造的(真实实例来源于数据集,伪造实例来源于生成模型)。
论文中的类比解释:生成模型像“一个造假团伙,试图生产和使用假币”,而判别模型像“检测假币的警察”。生成器(generator)试图欺骗判别器(discriminator),判别器则努力不被生成器欺骗。模型经过交替优化训练,两种模型都能得到提升,但最终我们要得到的是效果提升到很高很好的生成模型(造假团伙),这个生成模型(造假团伙)所生成的产品能达到真假难分的地步。
GAN网络整体示意如下:
有两个网络,(Generator)和(Discriminator)。Generator是一个生成图片的网络,它接收一个随机的噪声,通过这个噪声生成图片,记做。Discriminator是一个判别网络,判别一张图片是不是“真实的”。它的输入是,代表一张图片,输出代表为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
上面的 判别式模型拟合的是条件类别概率为,生成模型拟合的是联合概率分布
模型优化训练
再正式了解GAN模型优化训练过程前,先了解一下纳什均衡,纳什均衡是博弈中的局面:对于每个参与者来说,只要其他人不改变策略,他就无法改善自己的状况。转换到GAN中,生成模型恢复了训练数据的分布,判别器就无法再判别出正确的结果,即判别器开始乱猜。这是双方利益达到了最大化,不再改变策略,也就不再更新自己的权重了。
GAN的目标函数为:
上面的公式中, 为随机噪声;随机噪声服从的概率分布;生成器为输入噪声输出的假图像;为生成器生成的假图像服从的概率分布;为真实数据服从的概率分布;判别器输入的是图像,输出该图像来自的概率;为生成器的参数,为判别器的参数。公式的第一项是的期望,希望输入真实数据时,输出越大越好。公式中的第二项是的期望,希望输入假数据时,输出越小越好。
因此对抗体现在判别器希望能够正确的判别出图像是否来自于,即生成图像的判别结果趋近于0,真实图像的判别结果趋近于1;生成器希望输出的图像分布和真实分布接近,即生成图像的判别结果很小。
在训练生成器时损失的图像为下图,我们希望判别器对生成器生成的图像判别为正确的,因此希望生成器的损失越小越好。
那么他在训练的是时候是怎么训练的呢?从下面的伪代码中,我们可以发现在训练的过程中,首先固定生成器参数并训练多轮判别器,然后固定判别器参数并训练一轮生成器参数,由此交替迭代,使得对方的错误最大化,最终生成器能估测出样本数据的分布。
论文最后给出了对抗训练的过程图:
训练相关理论
命题一:当生成器参数固定时,最优的判别器为
它接收一个任意一个图片,输出最有判别器认为他是真实数据的概率。上式中,的结果不一定为1。
证明:根据期望的定义,可以由上文中GAN的目标函数转化为:
上式中积分变量不一致,利用测度论Radon-Nikodym Theorem(无意识统计学家定律)来切换积分变量,转化为:
无意识统计学家定律是用于切换积分变量:
当单射时,对噪声采样相当于对图像采样,即
上式的形式为可以看作为交叉熵,时在上取最大值。
对于上面的为常数,为自变量。交叉熵为,对其求导为,导数为0时。若且,在的结果为。根据上面的一阶导数为0、二阶导数为负,取得最大值。
定理1:当且仅当时,生成器价值函数取得全局最小值。
证明:在上文判别器已取到最优化时,训练器使判别器的价值函数最小化:
对上面式子做如下变换:
KL散度的定义为:
KL散度存在非负性和不对称性。JS散度的定义:设,则:
因此当时,得到全局最小值。
代码链接:github
参考文献: