在最初听说知识蒸馏技术的时候,我是持怀疑态度的,甚至觉得不可思议,为什么通过用简单模型去学习复杂模型的效果会比直接用训练标签来训练简单模型要好???
但是,它的存在必有其合理性,更何况是我偶像,深度学习第一人Hinton等人最早开始提出这种思想的.
于是便带着疑惑,对所谓的模型蒸馏技术做了一番研究,发现这个东西确实有过人之处,能够用更简单的模型获得更优质的推理效果,这在工程上,简直是妙不可言.下面就让我们来think think,模型蒸馏为什么有用,又是怎么来实现的.
什么是知识蒸馏
众所周知,对于各类任务,当有足够多的数据的情况下,我们的神经网络模型越大越深,往往效果也会越好,正如ResNet50在图像任务上摧枯拉朽,Large Bert在语言任务上效果拔群,除了优秀的模型结构涉及,可以归结为是大力出奇迹.
但是,在实际的生产中,部署一个很大的推理模型是十分困难的,因为它的计算量是无数大大小小公司不能承受之痛,并不是每个企业都像Google那样拥有成千上万的TPU,当然即使有,在大部分场景下,也显然是不划算的.为了解决日益增长的模型预测效果的追求和和工程师想要提高性能老板想要节省成本之间的矛盾,有人提出了知识蒸馏技术.
即我们先在原始的训练数据上训练一个大的复杂的拟合的好泛化能力也很好的巨无霸模型(教师模型),再用这个复杂模型的inference结果取代原有的标签,用于训练一个新的更小的效果跟教师模型相差不大的模型(学生模型).然后生产环节只要部署这个性能强劲和推理效果足够好的学生模型就可以了.
好,这个想法实在是太好了..但是旁观者大概会有些不明觉厉....直接从原始的训练数据学不好吗?干嘛还多此一举去学一个更不精确的拟合结果数据?
这样做自然是有好处滴,且听我给你慢慢分析...这一切应该从一个软妹字说起..... [噗..抱歉,多打了一个妹字...
一切从soft特征开始说起
人类能够非常好的从许许多多的特征之中找到主要特征来区分不同的物品,而不会被表面很多相似的特征所迷惑,比如,人类可以较好的区分一只像猫的狗或是一只像狗的猫,而对于深度神经网络来说,却并没有那么容易.正如Hinton等人的一个经典论述: 一辆宝马被深度网络识别为一台垃圾车的可能性很小,但是被错误的识别为一个胡萝卜的可能性却要高很多倍.
为了让网络能够获得学习这些东西的能力,我们不得不让网络变得更深更复杂.知识蒸馏的目的就是希望大模型能够将学习到的这些区分近似特征的能力教给小模型,教育这种知识的精髓就恰好在于用softmax的软特征来取代原始one-hot标注的硬特征.
仔细想一下,软特征的好处实际上是显而易见的.
就拿手写数字识别的例子来说,我们的标注数据的不同分类之间,实际是无法捕捉到它们之间的关系的,因为它们都是只有自己的分类位置是0,其余位置是1,每个目标向量之间的距离是一样的,因此这种标注的方式实际上是存在一定缺陷的,它无法包含这样一种信息:比如数字1,和只带有一点点弯曲的7实际是极为相似的,但实际的标注数据并不能体现这一点.但是经过一个大模型的学习之后,或许对于一个只有一点点弯曲的7模型的预测结果中,1的score是0.4,7的score是0.5,其余score都接近0. 当我们看到这样一组特征向量的时候,是可以很清晰的发现这个手写图片非常相7同时又有点像1而和其他数字不像.
因此,再用这个向量作为target给小模型进行学习的时候,小模型只需要很小的代价就能学习到这一复杂的关系了~
softmax的扩展
是不是觉得我上面的说法很有道理? 如果你真的就这么认为,那就too naive了! 梦想很丰满,而现实却很骨感..真实的情况是,经过softmax函数之后,几乎不可能出现某个分类0.5,另一个分类0.4的情况,更一般的是某个分类0.99,另一个分类0.01......
当然,别担心,前面的想法这么好,自然遇到一点困难不该轻易放弃,既然softmax不行,那我们就不如就给它调整一下..
Hinton等大佬的解决方案是:将原始logits传递给softmax之前,将教师模型的原始logits按一定的温度进行缩放.这样,就会在可用的类标签上得到更加广泛的分布.并且这个温度缩放机制同样可以用于学生模型.
然后,原始的softmax操作就变成了:
其中, 便是一个缩放因子的超参数,这些得到的结果便是所谓的软目标...
变大,类别概率就会变软,也就是说会相互之间更加接近,从而达到了捕捉类别间关系的目的.
除了上述这种方法,还有其他人有一些别的不使用softmax获得软特征的方法,各有优劣...因为想快点写完这篇,所以别的方法先不介绍了,有兴趣可以自己了解,或者改天有时间我回来补充上这个部分....
使用传统机器学习方法
如果想要更大限度的压缩模型,可以使用一些十分高效的传统机器学习方法作为学生去蒸馏
比如决策树。我觉得这可能是一个很好的方法,尽管它们的表达能力不如神经网络,但它们的预测非常可控和具有解释性,并有可能实现自动的更新和快速迭代.可以看一下Hinton他们的研究,读下这篇论文Distilling a Neural Network Into a Soft Decision Tree
他们的研究表明,尽管更简单的神经网络的表现比他们的研究要好,但蒸馏确实起到了一点作用。在MNIST数据集上,经过蒸馏的决策树模型的测试准确率达到96.76%,较基线模型的94.34%有所提高。然而,一个简单的两层深卷积网络仍然达到了99.21%的准确率。因此,在任务追求的精度和推理性能及边界性之间寻求一个权衡即可。
总结
个人认为知识蒸馏是一个极具前途的研究.它让更好的推理效果以更小更便捷的方式得以部署,这在工业界简直是无敌的存在.正所谓名师出高徒,和人类的学习一样,能够有一个牛逼的老师对你进行深入浅出的指导,能让你的学习过程事半功倍.而知识蒸馏,正好就是实现了这样一个深入浅出的功能,这种思想我个人十分推崇.