论文标题:《DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks》
论文地址:https://www.microsoft.com/en-us/research/publication/deepgbm-a-deep-learning-framework-distilled-by-gbdt-for-online-prediction-tasks/
背景
目前两种主流的算法在在线学习推荐系统中被广泛的应用,包括以GBDT为主的tree-model,以及NN模型(wide and deep、deepfm等)。但是目前这两种学习方法都有各自无法解决的困难。
首先GBDT在处理密集的连续型数值特征中表现出非常好的学习性能和效果,但是GBDT不能进行在线学习也不能很好的处理大量的稀疏特征,因为GBDT通常需要训练全量的数据而不能进行局部的更新,而且大量的稀疏特征会带来过拟合和空间爆炸的风险。,
其次NN能进行在线学习做到按batch进行更新,也能通过embeding的手段很好的处理大量的稀疏特征,但是在处理连续性特征的能力上却远远不如GBDT。
尽管已经有很多研究将这两种方法整合,尝试去融合两者的优点,但是却仍然达不到。
相关工作
就像我们之前提到的,我们现在将GBDT和NN广泛的应用于在线推荐系统中。下面我们首先回顾一下现在很多将GBDT和NN结合起来去解决他们各自缺点,融合他们优点的一些相关工作。并且基于之前的经验去构建一个有效的在线学习的模型。
我们之前提到过,GBDT用于在线学习主要是会存在两个问题,一个是tree model的不可拆分性,导致tree model不能局部更新,而是需要经过大量的离线数据的训练。第二个是GBDT不能非常有效的处理稀疏型的类别特征。基于这两个缺点,我们先回顾一下大家为了解决这个问题所做的一些工作。
1.有一部分研究尝试用流式的数据来训练tree model,但是这种方式只能用在单个tree model和并行的tree model适应,对于gbdt这种boosting的模型就不太适用了。而且这种模型抛弃掉了历史的数据、只用更新的数据会有偏。
2.对于gbdt不能处理稀疏的类别特征,因为不平衡的稀疏特征的分布的信息增益非常小。很多的研究尝试将稀疏的类别特征编码成连续性特征,但是这些编码方式会损失掉很多信息。还有一些方法会直接枚举左右的二进制分区,但是由于我们的特征稀疏,只有少量的数据是有值的,这样会导致过拟合。
3.还有很多方式将GBDT和NN结合,如Deepforest和mGBDT等,但是他们都不能解决这两个关键的问题。
对于NN的改造,现在很多的NN的方式都集中于解决大量的稀疏特征,而对于连续性的特征并没有很好的解决,FCNN为了解决这种办法才去全链接的方式来解决稠密连续特征。但是FCNN的性能不能满足要求。因为全连接会带来复杂的网络结构使得超参数的学习变得非常的复杂而且容易陷入局部最优解。另外的方式是直接把稠密的数值特征离散化,但是离散化带来的非线性特征使得模型变得复杂而且容易过拟合。
对于NN+gbdt组合的改造,主要有三类:
1.Tree-like NN
2.Convert Trees to NN
3.Combining NN and GBDT:直接将GBDT和NN组合,facebook直接将GBDT的叶子结点组合作为NN的类别特征喂给NN。Microsoft使用GBDT去训练NN的残差,但是这些都没有解决模型的线上问题
deepGBM结构
本文提出的deepGBM主要包含两部分:CatNN和GBDT2NN,结构示意图如下:
GBDT2NN for Dense Numerical Features
目前很多的研究方法只将GBDT的输出作为输入放入NN中去学习,这样会损失tree model的结构和很多的信息。本文我们用NN去模拟GBDT的结构。尽可能多的获取GBDT的结构信息。
1.特征选择。对于单颗树的蒸馏,我们可以将GBDT使用的特征的index输入到NN中,而不是将所有的特征都输入到NN中去。我们知道NN的模型是对所有的特征进行学习,没有特征选择的过程,而根据GBDT的原理,tree model的产生并不依赖于所有的特征,而是只选择部分的特征来做树的分裂和抉择。所以NN直接将用GBDT用到的特征来构建GBDT的结构。我们约定It 作为树t选择的特征。所以我们选择用x[It]作为NN模型的输入。
2.树结构。 根据树模型的特征会将不同label的样本分到不同的结点中,而同一结点的label是一样的。所以我们使用以下公式来近似我们的这种树结构:
其中n是样本的数量,N()表示NN的网络结构,0表示网络的参数, Lt,i是当前结点输出的叶子结点的onehot的表示形式。L是损失函数,如logloss等。
3.树模型的输出。通过上述的学习,我们只需要知道我们的树结构(叶子结点index)到模型输出的映射关系就可以了。这种index到value的相关性的表示为pt = lt * qt,qt表示当前树的values。所以最终模型的输出可以表示为:
以上是我们对单颗树的蒸馏,然后GBDT作为一种boosting的算法,经常是有很多的树组成的,那么对于多棵树的蒸馏又是怎样的呢?
对于多棵树的蒸馏我们主要提出了两种方法。首先为了减少树结构的维度,我们可以将叶子结点由原先的onehot方式变为映射到一个低维的embedding空间中。
其中wt和w0都是trainable的参数,pt,i是leaf index对应的value值。Ht,i是我们输出的embedding值。L是和树训练时一样的loss函数。
所以现在模型的学习目标变为
其次我们将一定数量的tree进行组合,让他们共享embedding的参数,这样能减少模型训练的复杂度。组合的方式有很多种,你可以选择模型结构高度相似的trees进行组合,或随机组合。本文选择随机组合,将m颗树分为k组,那么每个组合包含m/k树。所以现在的embedding损失变为:
而模型的训练目标变为:
模型最后的输出为: