transformer语言模型原理解读

一、简介

基于假设:一个词在句子中的意思,与上下文(语境)有关。与哪些词有关呢?Transformer就是:利用点积将句子中所有词的影响当成权重都考虑了进去。

Papper模型图

Transform模型是与RNN和CNN都完全不同的思路。相比Transformer,RNN/CNN的问题:

  1. RNN序列化处理效率提不上去。理论上,RNN效果上问题不大。
  2. CNN感受野小。CNN只考虑卷积核大小区域,核内参数共享,并行/计算效率不是问题,但受限于核的大小,不能考虑整个上下文。

在并行方面,多头attention和CNN一样不依赖于前一时刻的计算,可以很好的并行,优于RNN。在长距离依赖上,由于self-attention是每个词和所有词都要计算attention,所以不管他们中间有多长距离,最大的路径长度也都只是1。可以捕获长距离依赖关系。

二、注意力机制

注意力实际就是加权

2.1 NLP中的注意力

以RNN做机器翻译为例,下两图[1]分别是有没有注意力:


没有注意力机制的机器翻译,翻译下一词时,只考虑源语言经过网络后最终的表达(编码/向量);而注意力机制是要考虑源语言中每(多)个词的表达(编码/向量)。

NLP中有个非常常见的一个三元组概念:Query、Key、Value,其中绝大部分情况Key=Value。在机器翻译中,Query是已经翻译出来的部分,Key和Value是源语言中每个词的表达(编码/向量),没有注意力时直接拿Query就去预测下一个词,注意力机制的计算就是用Query和Key计算出一组权重,赋权到Value上,拿Value去预测下一词。


翻译编码解码模型[2]


计算权重[2]

加权[2]

2.2 自注意力

自注意力模型就是Query“=”Key“=”Value,挖掘一个句子内部的联系。计算句子中每个字之间的互相影响/权重,再加权到句子中每个字的向量上。这个计算就是用了点积。

Query、Key、Value都来自同一个输入,但是经过3个不同线性映射(全连接层)得到,所以未必完全相等。

公式中`QK^T`是Query向量和Key向量做点积,为了防止点积结果数值过大,做了一个放缩(`d_k`是Key向量的长度),结果再经过一个softmax归一化成一个和为1的权重,乘到Value向量上。

attention可视化的效果(这里不同颜色代表attention不同头的结果,颜色越深attention值越大)。可以看到self-attention在这里可以学习到句子内部长距离依赖"making…….more difficult"这个短语。

2.2.1 点积(Dot-Product)

  • 两向量点积表示两个向量的相似度。
  • 点积还有一个重要的特点是没有参数。

点积也叫点乘,一维点积用几何表示是: `\vec{a}\bullet\vec{b}=|\vec{a}||\vec{b}|\cos\theta` 。与我们常用的余弦相识度/夹角作用一样,与两向量的相似程度成正比。

2.2.2 具体计算过程:

假设我们句子长度设为512,每个单词embedding成256维。

  1. `QK^T`Query与Key点积。

Pytorch代码:

attn = torch.bmm(q, k.transpose(1, 2))
  1. scale放缩、softmax归一化、dropout随机失活/置零
    Pytorch代码:
attn = attn / self.temperature
if mask is not None:
    attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
  1. 将权重矩阵加权到Value上,维度未变化。

Pytorch代码:

output = torch.bmm(attn, v)

2.3 多头注意力

并不是将长度是512的句子整个做点积自注意力,而是将其“拆”成h份,没份长度为512/h,然后每份单独去加权注意力再拼接到一起,Q、K、V分别拆分。

“拆”的过程是一个独立的(different)、可学习的(learned)线性映射。实际实现可以是h个全连接层,每个全连接层输入维度是512,输出512/h;也可以用一个全连接,输入输出均为512,输出之后再切成h份。

多头能够从不同的表示子空间里学习相关信息。

在两个头和单头的比较中,可以看到单头"its"这个词只能学习到"law"的依赖关系,而两个头"its"不仅学习到了"law"还学习到了"application"依赖关系。

Pytorch实现:

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))  
        ...
    def forward(self, q, k, v, mask=None):

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

三、位置编码(Positional Encoding)

因为transformer没有RNN和CNN,为了考虑位置信息,论文中直接将全局位置编号加到Embedding向量每个维度上。
Pytorch代码:

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

另外,论文中位置编码还利用了sin/cos正余弦函数考虑周期性和归一化。

四、残差和前馈(Feed Forward)

4.1 为什么残差[3]

网络的深度为什么重要?

因为CNN能够提取low/mid/high-level的特征,网络的层数越多,意味着能够提取到不同level的特征越丰富。并且,越深的网络提取的特征越抽象,越具有语义信息。

为什么不能简单地增加网络层数?

对于原来的网络,如果简单地增加深度,会导致梯度弥散或梯度爆炸。

对于该问题的解决方法是正则化初始化和中间的正则化层(Batch Normalization),这样的话可以训练几十层的网络。

虽然通过上述方法能够训练了,但是又会出现另一个问题,就是退化问题,网络层数增加,但是在训练集上的准确率却饱和甚至下降了。这个不能解释为overfitting,因为overfit应该表现为在训练集上表现更好才对。
退化问题说明了深度网络不能很简单地被很好地优化。
作者通过实验:通过浅层网络+ y=x 等同映射构造深层模型,结果深层模型并没有比浅层网络有等同或更低的错误率,推断退化问题可能是因为深层的网络并不是那么好训练,也就是求解器很难去利用多层网络拟合同等函数。

怎么解决退化问题?

深度残差网络。如果深层网络的后面那些层是恒等映射,那么模型就退化为一个浅层网络。那现在要解决的就是学习恒等映射函数了。 但是直接让一些层去拟合一个潜在的恒等映射函数H(x) = x,比较困难,这可能就是深层网络难以训练的原因。但是,如果把网络设计为H(x) = F(x) + x,如下图。我们可以转换为学习一个残差函数F(x) = H(x) - x. 只要F(x)=0,就构成了一个恒等映射H(x) = x. 而且,拟合残差肯定更加容易。

4.2 前馈

每个attention模块后面会跟两个全连接,中间加了一个Relu激活函数,公式表示:

也可用两个核为1的CNN层代替。
两个全连接是512->2048->512的操作。原因未详细介绍。

五、训练-模型的参数在哪里

transformer的核心点积是没有参数,transform结构的训练,会优化的参数主要在:

  1. 嵌入层-Word Embedding
  2. 前馈(Feed Forward)层
  3. 多头注意力中的“切片”操作(映射成多个/头小向量)实际是一个全连接层(线性映射矩阵),以及多头输出拼接结果(Concat)后会经过一个Linear全连接层。这两个全连接层也是残差块有意义的地方,如果没有这一层,那这个注意力机制中就没有参数,残差就没有意义了。

六、参考文献

[1]. Neural Machine Translation by Jointly Learning to Align and Translate
[2]. 残差的解读

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,293评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,604评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,958评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,729评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,719评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,630评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,000评论 3 397
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,665评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,909评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,646评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,726评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,400评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,986评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,959评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,996评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,481评论 2 342

推荐阅读更多精彩内容