一、简介
基于假设:一个词在句子中的意思,与上下文(语境)有关。与哪些词有关呢?Transformer就是:利用点积将句子中所有词的影响当成权重都考虑了进去。
Transform模型是与RNN和CNN都完全不同的思路。相比Transformer,RNN/CNN的问题:
- RNN序列化处理效率提不上去。理论上,RNN效果上问题不大。
- CNN感受野小。CNN只考虑卷积核大小区域,核内参数共享,并行/计算效率不是问题,但受限于核的大小,不能考虑整个上下文。
在并行方面,多头attention和CNN一样不依赖于前一时刻的计算,可以很好的并行,优于RNN。在长距离依赖上,由于self-attention是每个词和所有词都要计算attention,所以不管他们中间有多长距离,最大的路径长度也都只是1。可以捕获长距离依赖关系。
二、注意力机制
注意力实际就是加权。
2.1 NLP中的注意力
以RNN做机器翻译为例,下两图[1]分别是有没有注意力:
没有注意力机制的机器翻译,翻译下一词时,只考虑源语言经过网络后最终的表达(编码/向量);而注意力机制是要考虑源语言中每(多)个词的表达(编码/向量)。
NLP中有个非常常见的一个三元组概念:Query、Key、Value,其中绝大部分情况Key=Value。在机器翻译中,Query是已经翻译出来的部分,Key和Value是源语言中每个词的表达(编码/向量),没有注意力时直接拿Query就去预测下一个词,注意力机制的计算就是用Query和Key计算出一组权重,赋权到Value上,拿Value去预测下一词。
翻译编码解码模型[2]
计算权重[2]
加权[2]
2.2 自注意力
自注意力模型就是Query“=”Key“=”Value,挖掘一个句子内部的联系。计算句子中每个字之间的互相影响/权重,再加权到句子中每个字的向量上。这个计算就是用了点积。
Query、Key、Value都来自同一个输入,但是经过3个不同线性映射(全连接层)得到,所以未必完全相等。
公式中是Query向量和Key向量做点积,为了防止点积结果数值过大,做了一个放缩(是Key向量的长度),结果再经过一个softmax归一化成一个和为1的权重,乘到Value向量上。
attention可视化的效果(这里不同颜色代表attention不同头的结果,颜色越深attention值越大)。可以看到self-attention在这里可以学习到句子内部长距离依赖"making…….more difficult"这个短语。
2.2.1 点积(Dot-Product)
- 两向量点积表示两个向量的相似度。
- 点积还有一个重要的特点是没有参数。
点积也叫点乘,一维点积用几何表示是: 。与我们常用的余弦相识度/夹角作用一样,与两向量的相似程度成正比。
2.2.2 具体计算过程:
假设我们句子长度设为512,每个单词embedding成256维。
- Query与Key点积。
Pytorch代码:
attn = torch.bmm(q, k.transpose(1, 2))
- scale放缩、softmax归一化、dropout随机失活/置零
Pytorch代码:
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
- 将权重矩阵加权到Value上,维度未变化。
Pytorch代码:
output = torch.bmm(attn, v)
2.3 多头注意力
并不是将长度是512的句子整个做点积自注意力,而是将其“拆”成h份,没份长度为512/h,然后每份单独去加权注意力再拼接到一起,Q、K、V分别拆分。
“拆”的过程是一个独立的(different)、可学习的(learned)线性映射。实际实现可以是h个全连接层,每个全连接层输入维度是512,输出512/h;也可以用一个全连接,输入输出均为512,输出之后再切成h份。
多头能够从不同的表示子空间里学习相关信息。
在两个头和单头的比较中,可以看到单头"its"这个词只能学习到"law"的依赖关系,而两个头"its"不仅学习到了"law"还学习到了"application"依赖关系。
Pytorch实现:
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
...
def forward(self, q, k, v, mask=None):
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
三、位置编码(Positional Encoding)
因为transformer没有RNN和CNN,为了考虑位置信息,论文中直接将全局位置编号加到Embedding向量每个维度上。
Pytorch代码:
# -- Forward
enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
另外,论文中位置编码还利用了sin/cos正余弦函数考虑周期性和归一化。
四、残差和前馈(Feed Forward)
4.1 为什么残差[3]
网络的深度为什么重要?
因为CNN能够提取low/mid/high-level的特征,网络的层数越多,意味着能够提取到不同level的特征越丰富。并且,越深的网络提取的特征越抽象,越具有语义信息。
为什么不能简单地增加网络层数?
对于原来的网络,如果简单地增加深度,会导致梯度弥散或梯度爆炸。
对于该问题的解决方法是正则化初始化和中间的正则化层(Batch Normalization),这样的话可以训练几十层的网络。
虽然通过上述方法能够训练了,但是又会出现另一个问题,就是退化问题,网络层数增加,但是在训练集上的准确率却饱和甚至下降了。这个不能解释为overfitting,因为overfit应该表现为在训练集上表现更好才对。
退化问题说明了深度网络不能很简单地被很好地优化。
作者通过实验:通过浅层网络+ y=x 等同映射构造深层模型,结果深层模型并没有比浅层网络有等同或更低的错误率,推断退化问题可能是因为深层的网络并不是那么好训练,也就是求解器很难去利用多层网络拟合同等函数。
怎么解决退化问题?
深度残差网络。如果深层网络的后面那些层是恒等映射,那么模型就退化为一个浅层网络。那现在要解决的就是学习恒等映射函数了。 但是直接让一些层去拟合一个潜在的恒等映射函数H(x) = x,比较困难,这可能就是深层网络难以训练的原因。但是,如果把网络设计为H(x) = F(x) + x,如下图。我们可以转换为学习一个残差函数F(x) = H(x) - x. 只要F(x)=0,就构成了一个恒等映射H(x) = x. 而且,拟合残差肯定更加容易。
4.2 前馈
每个attention模块后面会跟两个全连接,中间加了一个Relu激活函数,公式表示:
也可用两个核为1的CNN层代替。
两个全连接是512->2048->512的操作。原因未详细介绍。
五、训练-模型的参数在哪里
transformer的核心点积是没有参数,transform结构的训练,会优化的参数主要在:
- 嵌入层-Word Embedding
- 前馈(Feed Forward)层
- 多头注意力中的“切片”操作(映射成多个/头小向量)实际是一个全连接层(线性映射矩阵),以及多头输出拼接结果(Concat)后会经过一个Linear全连接层。这两个全连接层也是残差块有意义的地方,如果没有这一层,那这个注意力机制中就没有参数,残差就没有意义了。
六、参考文献
[1]. Neural Machine Translation by Jointly Learning to Align and Translate
[2]. 残差的解读