Attention Is All You Need-谷歌的"自注意力"

上一篇文章记录了自然语言处理中的注意力机制,这篇文章分析一下google的一篇论文Attention Is All You Need

为什么不使用循环神经网络

其实早在google之前,facebook就在[1]中抛弃了RNN等提出了基于卷积的sequence to sequence模型。由于RNN中时间步之间存在依赖关系,因此各时间步无法并行运算,使得GPU并行计算的优势无法发挥。在同等神经元量级的情况下,RNN训练速度较CNN相比更慢。因此很多研究中希望用其他的网络类型来替代RNN,这可能也是google这篇论文的出发点之一。

主要结构

首先看一些论文中对attention的定义,对于Q\in R^{m*d_k}, K \in R^{n*d_k}, V \in R^{n*d_v},有:
(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
这个定义表面上看和我们之前提到的注意力机制有很大的不同, Q, K, V 这几个矩阵也不知道什么意思。这里可以类比翻译任务,Q可以是看做是decoder的隐藏状态,而K可以看做encoder的隐藏状态,对于机器翻译任务K=V。在其他应用场景,例如在memory network中,K, V可能是分别用address的向量和用于修改memory的向量。这里QK^T实际上就是计算出一个权重,然后对V进行加权,也和之前提到的注意力结构一致。这里除以\sqrt {d_k}是避免QK乘积过大,softmax计算出结果出现上溢出。论文中给出的注意力计算流程如下:

(left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.

其中左边是普通注意力机制,右边则是论文中提到的对普通注意力机制的一种改进。也就是Multi-Head Attention

Multi-Head Attention

Multi-Head Attention首先对Q, K,V分别乘不同的变换矩阵进行变换,并重复多次这样的操作。写成公式就是:
\begin{aligned} \text { MultiHead }(Q, K, V) &=\text { Concat }\left(\text { head }_{1}, \ldots, \text { head }_{\mathrm{h}}\right) W^{O} \\ \text { where head }_{\mathrm{i}} &=\text { Attention }\left(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}\right) \end{aligned}
论文中提到这种方式可以学习到不同子空间的特征。这里其实有一点CNN的意思,对相同的数据用不同的核进行处理。

自注意力机制

论文中还提出了这种注意力机制的一个应用,也就是自注意力机制。自注意力机制也就是Q=K=V的情况。使用这种方式,论文中提出了transformer结构,并采用了这种结构实现了sequence to sequence模型:

transformer结构

下面分别对其中的模块进行讲解:

PE(Positional Encoding)

从图中可以看出输入先经过了一个Positional Encoding(位置嵌入, PE)。其实PE在其他论文[2]也有体现。进行PE的的主要原因是transformer中并没有任何可以体现输入顺序的结构。对于NLP来说,词语的顺序是非常重要的。因此论文中使用了PE。论文中直接指定了公式:
P E_{(p o s, 2 i)}=\sin \left(p o s / 10000^{2 i / d_{\mathrm{matel}}}\right)
P E_{(p o s, 2 i+1)}=\cos \left(p o s / 10000^{2 i / d_{\mathrm{model}}}\right)
其中pos是指词语在句子中的位置,i是每一个词向量中的第i个元素,d_model是词向量的维度。论文中提到使用学得的词向量结果与这种方式相似,于是论文就采用了这种方式进行PE。

Position-wise Feed-Forward Networks

论文中对该层的描述是:

a fully connected feed-forward network, which is applied to each position separately and identically

其实也就是核大小为1的卷积网络。公式描述如下:
\mathrm{FFN}(x)=\max \left(0, x W_{1}+b_{1}\right) W_{2}+b_{2}
max(0, .)也就是relu激活函数。

除了上面提到的几个模块之外,在Multi-Head Attention以及Feed Forward层都使用了残差连接以及layer normalization。‘

Self-Attention的优势

论文中给出了Self-Attention的几种优势:

  1. 由于Multi-Head Attention每一层都可以并行计算,因此计算速度相比RNN有优势。
  2. 在长距离的依赖问题有优势。在RNN中,反向传递梯度容易弥散。虽然在LSTM中引入了遗忘门等记忆单元,但是仍然在长序列时出现输出只依赖于最近几个输入的情况。也就是当依赖路径变长时,当前输出受其他输入之间的影响逐渐变小。但是在Self-Attention中,每一个输入都与其他输入进行了attention,因此在每一个输入中都包含了来至于其他输入的信息,这样使得每一个输入与其他输入的依赖路径变得更短,更容易学得更长的依赖关系。

实现

下面是我使用keras对Multi-Head Attention的一个实现,源代码如下:

class MultiHeadAttention(Layer):
    """
        multi head attention 的实现
        参考论文: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    """

    def __init__(self, num_heads, projection_shape, d_model, **kwargs):
        self._num_heads = num_heads
        self._d_model = d_model
        self._projection_shape = projection_shape
        super(MultiHeadAttention, self).__init__(**kwargs)

    def __add_weight(self, shape, name):
        return self.add_weight(name=name, shape=shape, initializer='normal', trainable=True)

    def build(self, input_shape):
        """
           以下不包括 batch_size
            input_shape = [m, n, d_model]
            这里的 m, n, d_model 代表的是为映射前的输入大小:
            shape(Q) = [m, d_model]
            shape(K) = [n, d_model]
            shape(V) = [n, d_model]

            d_k, d_v 是指 multi-head-attention 映射之后:
                            Q|K|V      x     QW|KW|VW 
            shape(Q*WQ) = [m, d_model] x [d_model, d_k] = [m, d_k]
            shape(K*WK) = [n, d_model] x [d_model, d_k] = [n, d_k]
            shape(V*WV) = [n, d_model] x [d_model, d_v] = [n, d_v]
            一般来说, dk, dv < d_model 因为论文中指出映射实际上有降维的左右,这样可以加快计算速度

            另外:
            shape(W_O)  =  [h * d_v, d_model]  
            
            符号与 *Attention is All You Need* 一致
            
        """
        d_k, d_v = self._projection_shape
        head_weight = self.__add_weight
        self._QW = head_weight([self._d_model, d_k * self._num_heads], "Q")
        self._KW = head_weight([self._d_model, d_k * self._num_heads], "K")
        self._VW = head_weight([self._d_model, d_k * self._num_heads], "V")
        self._OW = head_weight([self._num_heads * d_v, self._d_model], "O")

        super(MultiHeadAttention, self).build(input_shape)

    def __attention(self, Q, K, V):
        batch_dot = backend.batch_dot
        d_k, _ = self._projection_shape
        raw_weights = batch_dot(Q, tf.transpose(K, [0, 2, 1])) / tf.sqrt(tf.constant(d_k, dtype=tf.float32))
        attention_weights = tf.nn.softmax(raw_weights, axis=2)  # 每一行进行 soft max
        return batch_dot(attention_weights, V)

    def _mul_every_batch_with(self, inputs, y):
        """
        将 inputs 的每一个 batch 与 y 相乘, 产生一个新的张量
        :param inputs: [batch_size, m, n]
        :param y: [n, x]
        :return:[batch_size, m, x]
        """
        return tool.mul_every_batch_with(inputs, y)

    def _multi_attention(self, Q, K, V):
        mul_every_batch_with = self._mul_every_batch_with
        # 多次线性映射,连接
        QP = mul_every_batch_with(Q, self._QW)
        KP = mul_every_batch_with(K, self._KW)
        VP = mul_every_batch_with(V, self._VW)
        attention = self.__attention(QP, KP, VP)
        return mul_every_batch_with(attention, self._OW)

    def call(self, inputs, **kwargs):
        Q, K, V = inputs
        return self._multi_attention(Q, K, V)

也可以直接点链接attention.py。上面的代码只单纯的实现了MultiHeadAttention,并没有实现transformer结构。

参考

[1] Convolutional Sequence to Sequence Learning
[2] End-To-End Memory Networks
[3] Attention Is All You Need

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,902评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 84,037评论 2 377
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,978评论 0 332
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,867评论 1 272
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,763评论 5 360
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,104评论 1 277
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,565评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,236评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,379评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,313评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,363评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,034评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,637评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,719评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,952评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,371评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,948评论 2 341