InfoGAN:通过信息最大化的生成对抗网络进行的可解释表示的学习
摘要:
这篇论文描述了InfoGAN,一种对于对抗生成网络的信息理论上的扩展,它能够以完全无监督的方式学习分离的表达。InfoGAN是一个对抗生成网络,它也会最大化隐藏变量的一个小的子集和观察数据之间的互信息。我们推出了可以被高效优化的互信息目标函数的下界。特别地说,InfoGAN成功了从MNIST数据集的数字形状中分离出了书写风格,从3D渲染图片的光照中分离出了姿态,以及SVHN数据集的中央数字中分离出了背景数字。它也从CelebA人脸数据集中发现了一些包络发型,是否戴眼镜和表情等视觉概念。实验表明,InfoGAN学习到了可解释的表达,这些表达比现有的监督方法学习到的表达更有竞争力。
本文提出的GAN的架构如下图所示,生成器G的输入不仅仅是噪声z,而是增加了一个隐含变量的c,这个隐含变量在无监督学习中,并不明确其具体指定的含义,但是就是需要分离度语义信息。辨别器D其实存在两个,一个依旧是分别数据真伪结果,一个给出的是条件概率分布Q(c|x),但是这两个辨别器(D/Q)共用前面所有的卷积层,只是最后分别用不同的全连接层得到最后的输出结果。
这样的网络设计,在原有的GAN训练的loss函数中,加入了一个互信息项的loss,以鼓励生成器G在生成数据的时候,不仅仅使用噪声z,同时也利用隐藏变量c。设计过程中,是要求c与生成的数据的互信息形成新的loss项,但是需要计算条件分布存在困难,于是进行的数学换算过程,然后加入Q的辅助分布辅助这一计算,因此加入的loss项如下:
在LI(G,Q)中,将H(c)看做常数项,那么优化过程中,只要优化公式5中的前面一部分。对于离散变量c,文中指出其概率是将Q(c|x)通过softmax计算获得的结果(softmax一般用于概率的计算,然后Q(c|x)输出被认为本身就带有其概率,因而用softmax做概率的归一处理);对于连续的变量c,文中指出采用高斯分布来计算其概率,也就是说直接使用正态分布的概率密度函数直接计算某数值的概率即可。最终整个模型的训练的loss如下:
这篇文章中采用的互信息的内容,来保证输入的noise和产生的图片之间的联系,并且使用由此设计的Loss来强约束两者之间的关系,最终确保了模型学到了从noise中学到对应的语义信息来生成图片。从某种意义上而言,除了互信息的约束,可以尝试其他的约束来绑定noise和生成的图片之间的联系,从而设计新的loss,可以得到新的GAN的模型用来生成与noise息息相关的图片。
代码有两个版本的实现:
openAI的工业化程度比较高的版本,个人觉得这个代码比较晦涩难懂,同时有些地方可能有点问题:
https://github.com/openai/InfoGAN
比较易懂的代码,Tensorflow版本实现的,推荐先看这个版本,再去看上面那个版本会轻松很多: