WWW2020
文章简介:本文旨在利用用户-商品交互数据来增强知识库补全。使用了基于图神经网络的学习算法和生成对抗网络来融合交互信息和知识库数据。
原文
代码
文章最后会介绍代码结构。
目录:
1 INTRODUCTION
2 RELATED WORK
2.1 Knowledge Graph Completion
2.2 Collaborative Recommendation and KGC Models
2.3 Generative Adversarial Networks
3 PRELIMINARY
3.1 Knowledge Graph Completion(KGC)
3.2 User Interaction
3.3 Interaction-Augmented Knowledge Graph
3.4 Task Description
4 THE PROPOSED APPROACH
4.1 Overview
4.2 Collaborative Representation Learning over Interaction Augmented KG
4.2.1 Learning Entity-oriented User Preference
4.2.2 Learning Preference-enhanced Entity Representation
4.2.3 Discussion
4.3 User Preference Guided Discriminator
4.3.1 Discriminator Formulation
4.3.2 Discriminator Loss
4.4 Query-specific Entity Generator
4.4.1 Generator Formulation
4.4.2 Policy Gradient
4.5 Optimization and Discussion
5 EXPERIMENT
5.1 Dataset Construction
5.2 Experimental Setting
5.2.1 Evaluation Protocol
5.2.2 Methods to Compare
5.2.3 Implementation Details
5.3 Results and Analysis
5.4 Detailed Analysis of Performance Improvement
5.4.1 Performance Comparison w.r.t. Sparsity Levels
5.4.2 Performance Comparison w.r.t. Hop Number
5.4.3 Ablation Study
5.5 Performance Sensitivity Analysis
5.5.1 Varying the amount of KG triples
5.5.2 Varying the amount of user interaction data
5.6 Case Study
6 CONCLUSION
1 INTRODUCTION
大多数的KGC(Knowledge Graph Completion,知识库补全等同于KBC)方法都是设计一个新的学习算法来从知识图谱中获得表征。本文提出的切入点是利用实际应用获得的数据(本文表现为用户-商品交互数据)来加强KGC。这个切入带你的灵感是因为在线应用系统和知识库的实体有关联,比如说推荐系统的数据集MOVIE-LENS知识库Freebase在电影实体上有大面积重叠。用户交互数据也的确包含了很多有潜在价值的信息,原文中Fig1图解的很清楚这里就不搬运了。
之前也有结合用户交互信息来做KGC和推荐系统的工作。包括:
基于路径的方法(Recurrent knowledge graph embedding for effective recommendation.ACM2018),
基于正则化的方法( Transfer Learning for Item Recommendations and Knowledge Graph Completion in Item Related Domains via a Co-Factorization Model.ESWC2018),
基于图神经网络的方法( KGAT:Knowledge Graph Attention Network for Recommendation.KDD2019)。
这些方法的共同点是开发数据融合模型,然后在相同空间学习两种数据(数据库数据,用户交互数据)的表征。作者认为上述这些方法不好(hurt the original representation performance using simple fusion strategy)。所以本文提出了UPGAN((User Preference enhanced GAN)模型,这个方法好(Such an approach is effective to alleviate the issues about data heterogeneity and semantic complexity)。
2 RELATED WORK
2.1 Knowledge Graph Completion
TransE,GNN等技术加持的各种模型,可以学到知识图谱的图结构信息,但是没法融合用户交互数据。
2.2 Collaborative Recommendation and KGC Models
最近有几项研究旨在开发推荐系统和知识库补全的复合模型。比如:
因子分解(Transfer Learning for Item Recommendations and Knowledge Graph Completion in Item Related Domains via a Co-Factorization Model.ESWC2018),
关系迁移(Unifying Knowledge Graph Learning and Recommendation: Towards a Better Understanding of User Preference.WWW2019),
多任务学习( Multi-Task Feature Learning for Knowledge Graph Enhanced Recommendation.WWW2019),
图神经网络(KGAT: Knowledge Graph Attention Network for Recommendation.KDD2019)。
这些研究将不同数据表征在一个空间,同时采用了两套目标函数,来完成推荐系统和知识库补全两种任务。本文只考虑KGC这一种任务。
2.3 Generative Adversarial Networks
最近的研究将GAN用于图数据( GraphGAN: Graph Representation Learning With Generative Adversarial Nets.AAAI2018)和异构信息网络(Adversarial Learning on Heterogeneous Information Networks.KDD2019)上。这些方法用于节点分类等任务,无法用于我们的场景。也用于知识库补全(KBGAN: Adversarial Learning for
Knowledge Graph Embeddings.NAACK-HLT2018/ Incorporating GAN for
Negative Sampling in Knowledge Representation Learning.AAAI2018),这类方法旨在生成高质量负样本。
本文设计的模型,在D中融合用户交互数据,G中生产负样本来加强D。
3 PRELIMINARY
3.1 Knowledge Graph Completion(KGC)
KGC模型常见的套路是将实体和关系映射到低维空间,然后设计一个得分函数来计算输入三元组的可信度。
3.2 User Interaction
将用户和商品的交互也定义为三元组。一般交互可以包含很多种关系比如购买,点击。这里方便起见交互为一种关系。
3.3 Interaction-Augmented Knowledge Graph
通过交互信息中商品和知识库中实体的对齐,将交互信息和知识库合成一张图谱,节点有用户节点u,商品节点i,实体节点e三种。设计一个参数dn代表节点n和用户节点间的最小跳数(minimum hop number)。对于节点n,dn=0代表为用户节点u,dn=1代表为商品节点i,dn>1代表实体节点e。
3.4 Task Description
等同于链路预测(link prediction)预测三元组中缺失的实体。
4 THE PROPOSED APPROACH
UPGAN介绍。
4.1 Overview
因为两种数据的异构性,所以使用GAN来完成任务。G负责生成链路预测任务中需要预测的实体,D来判别生产实体的质量。G用来帮助D融合两种数据的能力。本文的任务可以由下列公式表现:
a是G生成实体,模型总体架构见下图:
4.2 Collaborative Representation Learning over Interaction Augmented KG
对于表征学习,简单的方法是所有节点一视同仁进行嵌入。然而节点存在异构性,所以本文基于图神经网络设计了一个两阶段的学习算法。
4.2.1 Learning Entity-oriented User Preference
这一步是从实体节点e向用户阶段u的传播(propogation),从图2中看就是Layer2或者Layer3往Layer0方向进行训练。公式2里第二项里的三元组(nj,r,nk)中dnj=dnk-1,即j为上一层节点k为下一层节点,从k往j(向图2中layer0)的方向传播。作者认为这样做的好处是减少了交互信息的噪音,同时考虑了路径的嵌入。作者这里还cue了下PTransE(Modeling Relation Paths for Representation Learning of Knowledge Bases,往期有笔记)。
4.2.2 Learning Preference-enhanced Entity Representation
和4.2.1反方向的训练,不再赘述了。同时图2里也解释的蛮清楚了。这里用到了图注意力机制。
4.2.3 Discussion
这里作者总结了一下4.2.1和4.2.3。注意本文的公式中,n表示所有节点(包括用户u,商品i,实体e)。按照我的理解4.2.1中的v表示了从右到左学习得到的嵌入,4.2.2中得p表示了从左到右学习得到的嵌入。所以对于同一个节点h,在这一步会存在两种不同的嵌入如公式8中。这样做的另一个目的就是可以保证G中的嵌入没有包含交互信息。
4.3 User Preference Guided Discriminator
与2.3中讨论的GAN相比,UPGAN的特色是将4.2中学习到带交互信息的嵌入应用到D中。
4.3.1 Discriminator Formulation
对于三元组(h,r,t),D输入h和r可以得到实体t的概率分布(公式6)。三元组(h,r,t)的可信度/得分函数为公式7。公式7中xq的来历见公式8。
4.3.2 Discriminator Loss
训练过程中最小化的损失函数。注意4.2的表征学习也是通过这个损失函数训练的。
4.4 Query-specific Entity Generator
G中不会用到用户交互信息因为作者认为其中噪音很大。
4.4.1 Generator Formulation
对于输入h和r,公式12产生了采样概率分布。C是所有候选实体集合,a是最后采样的实体组成负样例(h,r,a)。
4.4.2 Policy Gradient
文章沿用KBGAN使用policy gradient进行训练。Reward见公式13,损失函数公式14。
4.5 Optimization and Discussion
UPGAN首先对D进行预训练,然后按照标准的GAN训练过程,根据公式12进行采样,根据公式9和公式14进行训练。需要注意的是嵌入层的参数是和D一起训练的。
本模型可能是第一个用GAN进行知识库补全任务中使用到用户交互信息的模型。
5 EXPERIMENT
5.1 Dataset Construction
经过对KB4Rec数据集的一系列预处理后,得到的数据集如表1。
构造数据集时用到BFS4跳遍历实体(我的理解理论上最多layer6,初始实体在layer2)
5.2 Experimental Setting
5.2.1 Evaluation Protocol
采用链路预测(link prediction),指标有Mean Rank,top-k hit ratio,Mean Reciprocal Rank。
5.2.2 Methods to Compare
直接上图吧:
5.2.3 Implementation Details
一些超参数,需要注意的是训练使用DistMult初始化嵌入。G生成负样例一次1024个,然后从中采样200个。
5.3 Results and Analysis
见表3.
5.4 Detailed Analysis of Performance Improvement
5.4.1 Performance Comparison w.r.t. Sparsity Levels
将测试集按照答案实体的出现频率分成五组,来验证稀疏数据集上表现。
5.4.2 Performance Comparison w.r.t. Hop Number
测试不同hop数实体的表现(我的理解是layer2算1-hop,layer3算2-hop,以此类推)
5.4.3 Ablation Study
UPGAN以及3个阉割版本。第一行为去掉G普通负采样,第二个为去掉两阶段学习使用R-GCN,第三个为去掉两阶段学习使用GAT(没有用到用户交互信息)。可以看出用户交互信息还是对结果很有帮助。
5.5 Performance Sensitivity Analysis
5.5.1 Varying the amount of KG triples
采用不同训练集大小考察模型的表现。
5.5.2 Varying the amount of user interaction data
采用不同用户交互信息数量考察模型的表现。
5.6 Case Study
演示了如何从用户交互信息来完成知识库补全。
6 CONCLUSION
未来工作是推广到更多的数据集上。
代码结构(没有用到的文件和一些辅助文件,这里就不囊括进去了):
checkpoint/
data/
-
Model/
base_model.py
UGAT_mlp.py
layers.py
generator_contact.py
-
pretrain/
trainer.py
base_trainer.py
init.py
-
train/
trainer.py
base_trainer.py
init.py
load_data.py
evaluation.py
util/
main_pretrain.py
main_upgan.py
接下来是详细介绍:
-
checkpoint/
主要存放训练中产生的ckpt文件 -
data/
存放训练数据 -
Model/
存放模型。base_model.py
父类为torch.nn.Module,UGAT_mlp.py中模型的父类。get_user_all方法获取了DistMult的嵌入,随后对应公式3(这里有些奇怪,因为没有调用到gnn_layer_1.forward。get_user_all中使用self.gnn_layer_1的过程很奇怪)。fetch_user_batch方法调用了layers.py进行公式4。UGAT_mlp.py
父类在base_model.py中。query_layer方法调用了layers.py中的GAT。form_query方法对应文章公式7和8。注意这里UGAT类有forward方法但是并没有被用到过(代码中没有self.model.forward,只有self.query_layer.forward_rs这种。另外module_def只是给你看看结构,并没有按照这个顺序执行)layers.py
UGAT_mlp.py调用,包含了GCN和GAT的类。方法forward_rs和forward_kg对应公式4。generator_contact.py
G的模型。从代码中可以看到是采用DistMult加了一层全连接和激活。值得注意的是全连接层输入是embedding_size * 2,这是因为引入了噪声可以在forward_triple方法中看到(对应公式10,11,12)。这个方法被trainer.py中gen_step调用。
-
pretrain/
入口文件main_pretrain.py调用。-
trainer.py
Trainer_new类,包含模型训练的方法,超参数的初始化。dis_step方法对应公式6和9。负样本在构造的数据集里已经制作好(无需负采样)。 -
base_trainer.py
trainer.py中的Trainer_new的父类,包含模型的载入,归一化,路径采样等方法.show_norm方法对应日志中每个epoch后显示嵌入的范数信息。load_pretrain方法加载ckpt文件。 -
init.py
返回Model文件夹中的模型。trainer.py调用,生成trainer.model。默认使用的是UGAT_mlp.py
-
-
train/
入口文件main_upgan.py调用。-
trainer.py
和pretrain文件夹中trainer.py比,train_epoch_kg方法多了G的部分。dis_step输入的负样本来自方法gen_step,这个方法还包括了GAN的训练。gen_step通过generator_contact.py中获得的score采用torch.multinomial进行采样。 -
base_trainer.py
和pretrain文件夹中相比未发现重要区别。 -
init.py
和pretrain文件夹中init.py比,多了调用G模型的部分。默认使用的是generator_contact.py。 -
load_data.py
加载数据集,base_trainer.py(两个文件夹下的皆)调用。 -
evaluation.py
评估模型的方法,base_trainer.py(两个文件夹下的皆)调用。eva_rank_list方法计算mrr等,evaluate方法调用Model下文件里各个模型(self.model.evaluate)。
-
-
util/
一些训练日志,数据集存储,以及支持训练的辅助文件,就不详细介绍了。 -
main_pretrain.py
预训练入口文件,与main_upgan.py相比不包含G的部分。可以加载预训练的知识库嵌入,对应参数--load_ckpt_file,原文使用的是DistMult。这里值得注意的是知识库的嵌入,对于模型来说其实算是嵌入层的参数(作为模型参数的一部分,而且像DistMult这样Semantic match模型除了嵌入外没有额外参数)。训练完后在--checkpoint_dir下生成名为--experiment_name的ckpt文件(模型参数,见base_trainer.py的save_ckpt方法),log文件(训练日志),和middle文件(由evaluation.py里的write_res方法写入,尚不清楚有何作用) -
main_upgan.py
正式训练入口文件。参数--load_ckpt_file加载运行main_pretrain.py生成的experiment_name.ckpt文件,加载至self.model。参数--load_ckpt_G加载DistMult模型至self.G。另外可以从代码看到G用的就是DistMult加了一层全连接和激活。这里的self.model应该就是代表D了,self.D只出现在参数中一切D完成的功能都是由self.model来做。训练完生成文件和main_pretrain.py一致。