DVT:华为提出动态级联Vision Transformer,性能杠杠的 | NeurIPS 2021

论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错

来源:晓飞的算法工程笔记 公众号

论文: Not All Images are Worth 16x16 Words: Dynamic Transformers for Efficient Image Recognition

[图片上传失败...(image-eaf4e4-1663908579693)]

Introduction


  Transformers是自然语言处理 (NLP) 中占主导地位的自注意的模型,最近很多研究将其成功适配到图像识别任务。这类模型不仅在ImageNet上取得了SOTA,而且性能还能随着数据集规模的增长而不断增长。这类模型一般都先将图像拆分为固定数量的图像块,然后转换为1D token作为输入,拆分更多的token有助于提高预测的准确性,但也会带来巨额的计算成本(与token数成二次增长)。为了权衡性能和准确率,现有的这类模型都采用14x14或16x16的token数量。

[图片上传失败...(image-25ad6a-1663908579693)]

  论文认为不同图片之间存在相当大的差异,使用相同数量的token处理所有图片并不是最优的。最理想的做法应为每个输入专门配置token数量,这也是模型计算效率的关键。以T2T-ViT-12为例,官方推荐的14x14 token数仅比4x4 token数增加了15.9%(76.7% 对 60.8%)的准确率,却增加了8.5倍的计算成本(1.78G 对 0.21G)。也就是说,对“简单”图片使用14x14 token数配置浪费了大量计算资源,使用4x4 token数配置就足够了。

[图片上传失败...(image-67dcd-1663908579693)]

  受此启发,论文提出了一种动态Vision Transformer(DVT)框架,能够根据每个图片自动配置合适的token数,实现高效计算。训练时使用逐渐增多的token数训练级联Transformer,测试时从较少的token数开始依次推理,得到置信度足够的预测即终止推理过程。通过自动调整token数,“简单”样本和“困难”样本的计算消耗将会不一样,从而显着提高效率。

  另外,论文还设计了基于特征和基于关系的两种复用机制,减少冗余的计算。前者允许下游模型在先前提取的深度特征上进行训练,而后者允许利用上游模型中的自注意力关系来学习更准确的注意力图。

  DVT是一个通用框架,可集成到大多数图像识别的Transformer模型中。而且可以通过简单地调整提前终止标准,在线调整整体计算成本,适用于计算资源动态波动或需要以最小功耗来实现特定性能的情况。从ImageNet和CIFAR的实验结果来看,在精度相同的情况下,DVT能将T2T-ViT的计算成本降低1.6-3.6倍,而在NVIDIA 2080Ti上的真实推理速度也与理论结果一致。

Dynamic Vision Transformer


Overview

[图片上传失败...(image-aa2881-1663908579693)]

  • Inference

  DVT的推理过程如图2所示。对于每张测试图片,先使用少量1D token序列对其进行粗略表示,可通过直接使用分割图像块或利用如tokens-to-token模块之类的技术来实现,然后通过Vision Transformer对这些token进行快速预测。由于Transformer的计算消耗与token数量成二次增长,所以这个过程很快。最后基于预设的终止标准对预测结果进行快速评估,确定是否足够可靠。

  如果预测未能满足终止标准,原始输入图像将被拆分为更多token,再进行更准确、计算成本更高的推理。每个token embedding的维度保持不变,只增加token数量,从而实现更细粒度的表示。此时推理使用的Vision Transformer与上一级具有相同架构,但参数是不同的。根据设计,此阶段在某些“困难”测试图片上权衡计算量以获得更高的准确性。为了提高效率,新模型可以复用之前学习的特征和关系。在获得新的预测结果后,同样根据终止标准进行判断,不符合则继续上述过程,直到结果符合标准或已使用最终的Vision Transformer。

  • Training

  训练时,需保证DVT中所有级联Vision Transformer输出正确的预测结果,其优化目标为:

[图片上传失败...(image-6a8bb9-1663908579693)]

  其中,(x, y)为训练集D_{train}中的一个样本及其对应的标签,采用标准的交叉熵损失函数L_{CE}(·),而p_i表示第i个模型输出的softmax预测概率。

  • Transformer backbone

  DVT是一个通用且灵活的框架,可以嵌入到大多数现有的Vision Transformer模型(如ViT、DeiT和T2T-ViT)之中,提高其性能。

Feature and Relationship Reuse

  DVT的一个重要挑战是如何进行计算的复用。在使用的具有更多token的下游Vision Transformer时,直接忽略之前模型中的计算结果显然是低效的。虽然上游模型的token数量较少,但也提取了对预测有价值的信息。因此,论文提出了两种机制来复用学习到的深度特征和自注意力关系,仅增加少量的额外计算成本就能显着提高准确率。

  • Background

  介绍前,先重温一下Vision Transformer的基本公式。Transformer encoder由交替堆叠的多头自注意力(MSA)和多层感知器 (MLP)块组成,每个块的之前和之后分别添加了层归一化(LN)和残差连接。定义z_l\in R^{N\times D}表示第l层的输出,其中N是样本的token数,D是token的维度。需要注意的是,N=HW+1,对应H\times W图像块和可学习的分类token。假设Transformer共L层,则整个模型的计算可表示为:

[图片上传失败...(image-e3020e-1663908579693)]

  得到最终的结果z_L后,取其中的分类token通过LN层+全连接层进行最终预测。这里省略了position embedding的细节,论文没有对其进行修改。

  • Feature reuse

[图片上传失败...(image-e1414a-1663908579693)]

  DVT中的所有Transformer都具有相同的目标,即提取关键特征进行准确识别。 因此,下游模型应该在上游模型计算的深度特征的基础上学习才是最高效的,而不是从头开始提取特征。为此,论文提出了图3的特征复用机制,利用上游Transformer最后输出的结果z^{up}_L来生成下游模型每层的辅助embedding输入E_l

[图片上传失败...(image-3648e0-1663908579693)]

f_l:\mathbb{R}^{N\times D}\to \mathbb{R}^{N\times D^{'}} 由LN+MLP(\mathbb{R}^{D}\to \mathbb{R}^{D^{'}})开头,对上游模型输出进行非线性转换。转换后将结果reshape到原始图像中的相应位置,然后上采样并展平来匹配下游模型的token数量。一般情况下,使用较小的D^{'}以便快速生成f_l

  之后将E_l拼接到下游模型对应层的中间特征作为预测的先验知识,也就是将公式3替换为:

[图片上传失败...(image-57510d-1663908579693)]

E_l与中间特征z^{'}_l拼接,LN 的维度和MLP的第一层从D增加到D+D^{'}。 由于E_l是基于上游输出z^{up}_L生成的,token数少于z^{'}_l,它实际上为z^{'}_l中的每个token总结了输入图像的上下文信息。 因此,将E_l命名为上下文embedding。此外,论文发现不复用分类token对性能有提升,因此在公式5中将其填充零。

  公式4和5允许下游模型在每层灵活地利用z^{up}_L内的信息,从而最小化最终识别损失,这种特征重用方式也可以认为隐式地扩大了模型深度。

  • Relationship reuse

  Vision Transformer的关键在于自注意力模块能够整合整个图像的信息,从而有效地模拟图像中的长距离关系。通常情况下,模型需要在每一层学习一组注意力图来描述token之间的关系。除了上面提到的特征复用,论文认为下游模型还可以复用上游模型产生的自注意力图来进行优化。

  定义输入特征z_l,自注意力模块先通过线性变换得到query矩阵Q_l、key矩阵K_l和value矩阵V_l

[图片上传失败...(image-91e869-1663908579693)]

  其中,W^Q_lW^K_lW^V_l为权重矩阵。然后通过一个带有softmax的缩放点乘矩阵运算得到注意力图,最后根据注意力图来计算所有token的值:

[图片上传失败...(image-b10b54-1663908579693)]

  其中,dQK的点积结果维度,A_l\in \mathbb{R}^{N\times N}为注意力图。为了清楚起见,这省略了多头注意力机制的细节,多头情况下A_l包含多个注意力图。

  对于关系复用,先将上游模型所有层产生的注意力图(即A^{up}_l, l\in \{1,\cdots , L\})拼接起来:

[图片上传失败...(image-1e92b4-1663908579693)]

  其中,N^{up}N^{Att}_{up} 分别为上游模型中的toekn数和注意力图数,通常N^{Att}_{up} = N^H LN^H是多头注意力的head数,L是层数。

  下游的模型同时利用自己的token和A^{up}来构成注意力图,也就是将公式7替换为:

[图片上传失败...(image-7786fd-1663908579693)]

[图片上传失败...(image-9c7b27-1663908579693)]

  其中r_l(\cdot)是一个转换网络,整合A^{up}提供的信息来细化下游注意力图A_lr_l(\cdot)的架构如图5所示,先进行非线性MLP转换,然后上采样匹配下游模型的注意力图大小。

  公式9虽然很简单,但很灵活。有两个可以魔改的地方:

  • 由于下游模型中的每个自注意力模块可以访问上游模型的所有浅层和深层的注意力头,可以尝试通过可学习的方式来对多层的注意力信息进行加权整合。
  • 新生成的注意力图和复用注意力图直接相加,可以尝试通过可学习的方式来对两者加权。

[图片上传失败...(image-1d6cfb-1663908579693)]

  还需要注意的是,r_l(\cdot)不能直接使用常规上采样操作。如图5所示,假设需要将HW\times HW(H =W = 2)的注意力图映射上采样到H^{'}W^{'}\times H^{'}W^{'}(H^{'} =W^{'} = 3)的大小。由于每一行对应单个token与其他H\times W个token的关系,直接对注意力图上采样会引入混乱的数据。因此,需要先将行reshape为H\times W,然后再缩放到H^{'}W^{'}\times H^{'}W^{'},最后再展平为H^{'}W^{'}向量。

  • Adaptive Infernece

  如前面所述,DVT框架逐渐增加测试样本的token数量并执行提前终止,“简单”和“困难”图像可以使用不同的token数来处理,从而提高了整体效率。对于第i个模型产生的softmax预测p_i,将p_i的最大项max_j p_{ij}与阈值{\mu}_{i}进行比较。如果max_j p_{ij}\ge {\mu}_{i},则停止并采用p_i作为输出。否则,将使用更多token数更多的下游模型继续预测直到最后一个模型。

  阈值\{\mu_1, \mu_2, \cdots\}需要在验证集上求解。假设一个计算资源有限的批量数据分类场景,DVT需要在给定的计算预算B > 0内识别一组样本D_{val}。定义Acc(D_{val}, \{\mu_1, \mu_2, \cdots\})FLOPs(D_{val}, \{\mu_1, \mu_2, \cdots\})为数据集D_{val}上使用阈值\{\mu_1, \mu_2, \cdots\}时的准确度和计算成本,最优阈值可以通过求解以下优化问题得到:

[图片上传失败...(image-78b62f-1663908579693)]

  由于公式10是不可微的,论文使用遗传算法解决了这个问题。

Experiment


[图片上传失败...(image-c087ff-1663908579693)]

  ImageNet上的性能对比。

[图片上传失败...(image-481ebd-1663908579693)]

  推理性能对比。

[图片上传失败...(image-cbdec4-1663908579693)]

  CIFAR上对比DVT在不同模型规模的性能。

[图片上传失败...(image-9723be-1663908579693)]

  在ImageNet上与SOTA vision transformer提升方法的性能对比。

[图片上传失败...(image-36be9a-1663908579693)]

  基于DeiT的DVT性能对比。

[图片上传失败...(image-87f2e0-1663908579693)]

  复用机制的对比实验。

[图片上传失败...(image-65e8e3-1663908579693)]

  与类似的提前退出方法的性能对比。

[图片上传失败...(image-78e207-1663908579693)]

  复用机制提升的性能与计算量。

[图片上传失败...(image-aa5e1f-1663908579693)]

  复用机制实现细节的对比实验。

[图片上传失败...(image-1bd3f9-1663908579693)]

  难易样本的例子以及数量分布。

[图片上传失败...(image-5717b3-1663908579693)]

  不同终止标准的性能对比。

[图片上传失败...(image-c12515-1663908579693)]

  与自适应深度方法进行性能对比,自适应方法是在模型的不同位置插入分类器。

Conclusion


  论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,761评论 5 460
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,953评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,998评论 0 320
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,248评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,130评论 4 356
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,145评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,550评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,236评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,510评论 1 291
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,601评论 2 310
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,376评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,247评论 3 313
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,613评论 3 299
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,911评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,191评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,532评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,739评论 2 335

推荐阅读更多精彩内容