DGI: Deep Graph Infomax 阅读笔记
论文来源:2019 ICLR
论文链接:Deep Graph Infomax
论文原作者:Petar Veličković, William Fedus, William L. Hamilton, Pietro Liò, Yoshua Bengio, R Devon Hjelm
代码链接:https://github.com/PetarV-/DGI
(侵删)
近年来,图神经网络领域(GNN)在蓬勃发展中,许多优秀的论文和想法不断地被提出。以卷积神经网络为主,一系列的图卷积论文占领了GNN的主流地位。目前来说,GNN的很多算法都是由NLP以及CV领域“迁移”过来,毕竟相对于这两个领域来说,网络领域属于“小众”范围。所以很多GNN相关论文更多地是在探索如何将其他领域最先进的算法应用在“非欧”数据上。
这篇论文的作者之前提出过基于attention机制的图卷积,GAT算法,算是GNN领域中走得比较前沿的人;这篇论文核心也是如何将最新的Deep Infomax算法“迁移”地应用到图领域中。笔记中可能有一些自己理解的不对的地方,欢迎交流。
Introduction
这篇文章是以非监督的方式去学习节点的嵌入向量。
如何为无监督学习设置loss是很重要的。如果是在一个最简单的自编码器中,那么loss很简单,就是输入与输出的重构误差(当然,也可以对隐藏层进行一些约束)。
而在对网络进行无监督地特征学习中,大多是基于“随机游走式的目标函数”的一些算法,比如Deepwalk, node2vev等,以及通过auto-encoder来重构邻接矩阵或特征矩阵的一些算法,比如sdne,VGAE,EP等。这些算法在设置目标函数时,目的都是想让在原本网络中相近的节点,在嵌入空间中,也是相近的。
那如何去度量在原本网络中,两个节点的相近程度呢?随机游走式目标函数的这些算法主要是根据邻接矩阵,学习出PPR(Personalized PageRank)值,从而求出每一个节点对之间的相近程度。这些算法主要会有两个限制:一是过于强调节点的邻近信息,忽略了节点的结构信息;二是算法对超参比较敏感。
目前为止,大多数GNN算法都是以卷积为主,如图1所示。
第k layer中,节点b的向量,会融合(k-1) layer中,它的邻居节点a,c,d,e的信息。GraphSAGE算法随后将图卷积学习到的节点向量
,与负采样形式的随机游走式目标函数结合,从而可以成功训练,loss函数为:
这样的做法,从loss函数角度上去看,编码器会使得邻接矩阵中相邻的节点拥有相似的表达向量。但从图卷积这个角度上去看,卷积的过程本来就将它的邻居信息融合了,在经过几次这样的patch-level的融合,相邻的节点就已经拥有相似的表达向量了,因此基于随机游走式目标函数可能就不能够再给训练过程提供有效的梯度指示了。
因此,需要寻找其他的有效目标函数。这篇文章的做法,则是受最近引起广泛关注的Deep Infomax(DIM)的启发,将目标函数设成最大化互信息,因此接下来大概介绍下DIM。
Deep InfoMax
DIM的原文没太看懂,专栏PaperWeekly的文章“深度学习中的互信息:无监督提取特征“对DIM有一个比较清晰的讲解,在这里根据这篇文章,就大概讲一下自己的理解。
首先,DIM认为,重构误差小,不能说明学习出来的特征好,好特征应该是能够提取出样本的最独特,具体的信息。那如何衡量学习出来的信息是该样本独特的呢?这里就是用“互信息”(Mutual Information,MI)来衡量。因此,DIM的一个核心idea是,训练一个编码器,它的目标函数,不是最小化输入与输出的MSE,而是最大化输入与输出的互信息。
互信息是概率论和信息论中重要的内容,它表示的是一个随机变量中包含另一个随机变量的信息量,可以理解成两个随机变量之间的相关程度。(X表示输入,Y表示编码器的输出,即学习出来的特征向量)
也会尽量变大,这就意味着p(y|x)会大于p(y),也就是说对于每个输入样本x,编码器能够尽可能地找出专属于样本x的特征y。因此,这样一来,只通过特征y,也能很好地分辨出原始样本来(因为学习到的特征含有样本的独特信息)。
那么如何最大化互信息呢?
将互信息稍加变换可以发现,互信息实际上就是变量x,y的联合分布与它们边缘分布的乘积的KL散度。也就是说,最大化互信息,就是要拉大联合分布与边缘分布乘积的距离。由于KL 散度理论上是无上界的,因此为了更有效地优化,可以利用JS散度与KL散度之间的转换关系:
JS散度是有上界的,它的上界为log2/2。此时就可以将最大化互信息这个问题转换成最大化JS散度。
现在的问题就变成了如何去有效的计算JS散度了? 这里要介绍下f-GAN。
f-GAN
f-GAN这篇论文是通过GAN来快速有效地对各种散度进行估算,一些详细的内容可以查看博客:f-GAN简介:GAN模型的生产车间。
在机器学习中,计算两个概率分布P,Q的散度是有一定难度的,因为很多时候是无法知道两个概率分布的解析形式,或者分布只有采样出来的样本(这时就是比较两批样本之间的相似性)。
散度的一般化形式可以写作:
f-GAN是通过“局部变分技巧“来进行快速地估算,具体的细节这里就不讨论。根据f-GAN,最终估算散度的公式为:
其中g为f函数的共轭。这个公式的意思就是,分别从两个分布进行采样,然后计算T(x)与g(T(x))的平均值,优化T,使得它们的差最大,最终的结果即为散度的估算值。T(x)可以用足够复杂的神经网络去拟合。
对于JS散度,它的函数形式f(u),共轭形式g(t),值域,以及最后T(x)所用的激活函数为:
因此,最终JS散度的估算公式(省略了常数项)为:
因此,最大化mutual information的目标函数为:
这个公式实际上就是“负采样估计”:引入一个判别网络 σ(D(x,y)),x 及其对应的 y 视为一个正样本对,x 及随机抽取的 y 则视为负样本,然后最大化似然函数(等价于最小化交叉熵)。
再返回到Deep InfoMax
上面简单的介绍了如何解决最大化互信息的方案,而DIM第二个核心的idea就是它是最大化局部特征与全局特征的互信息。因为对于图片,它的相关性更多体现在局部中,图片的识别、分类等应该是一个从局部到整体的过程。简单来说就是,全局特征更适合用于重构,局部特征更适合用于下游的分类任务。(局部特征就是卷积后得到的feature map;全局特征就是对feature map进行编码得到的feature vector)
因此最终DIM的算法流程为:
Deep Graph InfoMax
那现在的问题就是,如何将DIM算法应用到网络领域上。我们需要解决4个问题:
- 如何得到局部特征(patch representations);
- 如何得到全局特征(global summaries);
- 如何得到负样本的局部特征;
- 如何设计判别网络D,区分正负patch-summary pairs;
Patch representations :
与图片不同,对于网络,每一个节点的特征向量,即为该节点的局部特征。论文中采用的是图卷积核编码器来学习节点的向量,运用GCN算法,可以将周围邻居的信息整合起来,因此又可将节点向量称作patch representations:
Global Summaries :
全局特征通过readout函数
来获得。论文中说到,实验中尝试了几种readout函数(比如set2vec等等),发现效果最好且实现起来最简单的,就是对所有patch representations取平均。
Obtaining negative patch representations:
在DIM中,负样本是直接用另一张照片作为fake输入;然而在网络中,我们针对的是节点,因此数据点都是属于同一个graph的。因此,论文中是通过一个“腐蚀函数”
。
Discriminating positive and negative patch-summary pairs:
最后就是如何去区分正负“样本对”。令为正样本对,
为负样本对;之后类似DIM模型一样,通过一个判别器
(实际上是双线性二元分类器)来为“样本对”来打分。
最后,整个目标函数(最大化JS散度形式的互信息)为:
整个DGI的过程如图所示:
Experiments
实验部分这里就不详细展开了,代码链接已给出。
对于大规模网络来说,邻接矩阵太大了,是读不进内存的,因此论文中提到要做采样。比如,首先从所有节点中采样出一个minibatch,共256个节点。之后将每一个当作中心节点,围绕着它采样10个邻居,10个邻居的邻居,25个邻居的邻居的邻居,因此最终共有1+10+100+2500=2611个节点被采样出来,形成一个子图,然后根据这个子图去计算中心节点的patch representations。最后将256个节点的patch representations平均,得到summary vector 。
学习到节点的向量表达后,下游实验是用简单的线性逻辑回归进行分类任务。实验涉及到了直推学习和归纳学习。
Conclusion
这篇论文写作上还是比较清晰的,首先介绍了为什么对于图卷积网络来说,要寻找其他的目标函数;接着介绍论文是如何将DIM的最大化互信息思想应用到图领域中;最后通过理论(这里没太看懂)以及实验来证明最大化互信息的目标函数对节点学习向量表达是有帮助的;最后附录给出了一些鲁棒性分析的实验。实验上也比较详尽,具体实验设置也都详细说明,个人感觉,整篇文章是很“ICLR风格”的一篇论文。
但可能理解得不够透彻,总是感觉论文对公式的直观解释有点牵强,而且没有比较只有全局互信息和只有局部互信息时的实验。但总得来说,这篇文章给网络嵌入学习领域提供多了一个目标函数的设计,可能之后也可以往这方面多思考思考。
(引用的相关的论文和博客都已附上链接)