首先需要明确的是,Transformer是一个翻译模型。与之前主流的翻译模型相比,transformer的依然是一个encoder-decoder结构,改变的主要是encoder和decoder内部的组成,改变结构带来的优势是使得模型可以并行化训练。Transformer的结构如下:
图中左边就是encoder部分,右边就是decoder部分。
下面来具体看下两个部分的组成。首先看encoder部分,如下图:
transformer中encoder的主要组成.png
encoder主要结构组成包括一个self-attention模块和一个FFN模块。之所以是self-attention,是因为输入的Q、K、V都是输入句子对应的embedding矩阵。此外还加上了resnet中的残差结构,减少了梯度消失的风险,使得模型更好训练。
再来看看decoder部分,如下图:
transformer中decoder的主要组成.png
decoder主要结构组成包括一个self-attention模块,一个encoder-decoder-attention模块和一个FFN模块。其中self-attention模块和FFN模块和encoder部分中的相同,而encoder-decoder-attention模块在结构上也和self-attention模块一样,不同点在于encoder-decoder-attention模块中的K和V是encoder部分的输出,Q是自身self-attention模块的输出。
前文已经提到,transformer的主要改进在于可以并行化训练。之前主流的翻译模型中encoder和decoder组成都是RNN,但是RNN是不能并行化训练的,只有前一时刻训练结束才能进行当前时刻的训练(就像GBDT一样,每棵树的训练是依赖前一棵树的)。但是transformer中attention结构不一样,它是将每个时刻的输入信息之间的距离视为1,任意两个时刻的输入信息是可以直接交互运算的。attention中信息之间的交互计算是通过矩阵运算实现的,而矩阵运算是可以很好的进行并行计算的。RNN和attention中信息的处理如下图:
RNN中t时刻信息与t-2时刻信息的交互必须要先经过t-1时刻,但是attention中t时刻和t-2时刻之间可以直接交互。没有了时序上的限制,attention结构就可以进行并行化计算了。同时attention结构另外一个优势在于长期依赖问题得到缓解,比如RNN中t+2时刻信息与t-2时刻的信息依赖关系之间相差4步,但是attention中两者可以直接计算,因此可以认为两者相距为1。不过尽管attention中长期依赖问题得到缓解,但是也带来另外一个问题,就是对位置信息的忽略。时序数据中数据的先后关系是很重要的,比如一句话中两个单词的颠倒可能就会导致完全不同的含义。Transformer的补救措施是给输入信息加上了位置编码,如下图:
也就是对每个单词对应的embedding向量再加上位置编码向量。位置编码公式如下:其中,是信息在序列中的位置,而是位置编码的维度序号,且。为什么要用这个函数呢?文中给的解释是:
We chose this function because we hypothesized it would allow th emodel to easily learn to attend by relative positions, since for any fixed offset k, can be represented as a linear function of
这里我给下个人解释,位置编码后会得到一个矩阵,行数就是序列的长度,列数就是embedding维度。每一列上的值都是从同一个频率的正弦函数上取出的,每一行就是一个单词的位置编码向量,这个向量是由多个不同频率的正弦函数组成。由于相同列都是从同一频率的正弦函数上取的值,所以不同行向量之间可以线性表示。不过为什么满足这样特性的编码函数就是好的函数我还是不太明白,这里挖个坑,以后弄懂了再来写。
在PyTorch实现代码中实现位置编码时并不是直接实现上述公式,而是做了点改变。改变在于将公式中改为。个人猜测这样改变可能会节约时间吧。
有一点需要注意:transformer中的并行化计算只在encoder中进行,在decoder中是不可以并行化计算的。(经评论指正,在训练阶段decoder中也是可以进行并行化计算的,只是在预测阶段不可以进行并行化计算)decoder中attention结构的K和V是不变的,但是output embedding是变化的,依赖于上一时刻decoder的输出。
接下来再来讲下transformer中的attention结构。Transformer中最基本的attention结构是Scaled Dot-Product Attention,结构如下:
Scaled Dot-Product Attention结构.png
其中,Q和K的列数是一样的,而K和V的行数是一样的。Attention函数可以理解为将一个query和一个(key, value)对映射成一个输出,query和key进行矩阵相乘得到相应的weight,然后再将weight和value进行矩阵相乘得到最终的输出。Transformer中用的是Multi-Head Attention,其实也是Scaled Dot-Product Attention组成的,Multi-Head Attention结构图如下:
Multi-Head Attention结构.png
其实Multi-Head Attention就是由多组Scaled Dot-Product Attention组成,transformer中先将高维的Q、K、V映射到多个低维空间中,在每个低维空间中进行attention操作得到一个低维的输出,然后再将这些低维输出拼接起来得到和原始维度一样的输出。这样做的好处在于不增加计算量的情况下使得attention的效果更好了。下面是Multi-Head Attention的PyTorch实现代码:
Multi-Head Attention实现代码.png
其中self.linears的作用就是将原始的Q、K、V映射到不同的低维空间中。