简单的NN的基础概念在这里不多说了。RNN的中文名是卷积神经网络,是一种广泛应用的神经网络类型。RNN常常用于sequence data,所以对于文本(词序列)、或者vedio(每一帧每一帧来看)等数据有不错的效果。这里主要总结一下RNN和LSTM的一些相关笔记。
(Note: Some of the contents (images, slides) are taken or modified from Stanford CS321 slides)
一. RNN Basis
RNN的基本构造:
具体来说:
这里x_t是在t时刻的输入,h_t是hidden state。
二、Vanilla RNN Gradient Flow
传统RNN的一个问题是在训练过程中,反向传播会出现gradient explode/vanish的问题,用中文来理解就是梯度消失或者梯度爆炸。具体来看:
如上图所示,这里是一个RNN的cell,灰色箭头是正向传播,红色箭头是反向传播。以这个cell的反向传播为例,根据复合梯度的乘法原理,在计算反向传播时会对d(h)乘以W。那么对于整个RNN来说:
如图所示,对于有四个node的网络来说,从h_4到h_0的反向传播过程中会有 D(h_0) x W x W x W,也就是乘以W的四次方。这样就会产生梯度爆炸或者梯度消失的问题。因为RNN中的一个常用机制是权值共享机制,也就是说W是一样的;那么如果W矩阵的最大特征值 > 1, 就会产生题都爆炸的情况,也就是梯度传到比较靠前的状态节点比如t = 0的时候,因为乘了W的n次方,梯度此时会非常大;反之,如果W的最大特征值 < 1,就会使得梯度近似于0。(试想一下W是一个标量,如果W大于1,那梯度是不断增加的过程;如果W小于1,梯度会越来越小最终趋于0)
当然,对于gradient explod我们可以采用clipping的方法,即如果梯度值 G > max_value(eg: 50)时,我们可以令G = 50。但是对于gradient vanishing的问题,并没有很好的解决方法。
三、 LSTM
针对RNN的题都爆炸/消失,研究者们提出来LSTM。LSTM是long short time memory的简称。
不同于RNN,在lstm中我们除了hidden state以外,还增加了c, i, f, o, g几个变量。具体来说:
c:cell state,不同于hidden state,cell state 是不会和cell外界有接触的;
i:input gate,决定是否写入cell;
f: forget gate: 决定是否消除cell;
o: output gate, 决定要反映多少的信息;
g: (没有一个特定的名字),用来决定要写入cell多少信息。
具体来说:
对于一个cell来说,它的结构如下图:
如图中所示:我们现在的state包括h和c两个变量。在t时刻,输入x(t)和t-1时刻的状态h(t-1)通过stack合成( h(t-1),x(t) )后与W权重相乘,之后生成f, i, g, o四个gate。
c(t-1)通过一系列与当前gate的操作后生成c(t), 所以c(t)只取决于上一时刻的cell state和gate值,与外界信息是没有直接接触的。直观理解,c(t)一方面取决于c(t-1);一方面取决于i和g。通过刚才对这四个gate的定义,f和c(t-1)的乘法得到了究竟我们想对c(t-1)的信息留下多少;i和g的乘法得到了我们对当前时刻的新信息希望记住多少。
另一方面,当前时刻的hidden state,h(t)取决于output gate与c(t)的操作。
对于LSTM中的gradient flow,如下图所示:
同样,红色是方向传播,灰色箭头是正想传播。因为此时cell state c的存在,在反向传播中只有加法和乘法的操作,一定程度上弥补了hidden state h的梯度消失;另一方面,我们要注意c的乘法操作中是按位相乘而不是矩阵相乘;按位相乘要比矩阵相乘在梯度消失的问题上友好的多。
好了,最后总结一下吧:
1. RNN是一个很有用的神经网络结构;适用于squence data;
2. Vanilla RNN很简单但是往往工作效果一般;
3. RNN有梯度消失(使用LSTM)或者梯度爆炸(用clipping解决)的问题。
同时,LSTM虽然看上去复杂,但是现在有很多现有的计算平台可以调用;比如tensorflow,torch,keras等等。
PS:本来想介绍一下基于RNN的language model,但是实在太懒了,贴个视频过来,不过youtube需要翻墙。https://www.youtube.com/watch?v=Keqep_PKrY8&list=PL3FW7Lu3i5Jsnh1rnUwq_TcylNr7EkRe6&index=8
注:所有图片和部分内容均来自于Stanford CS231 slides。
有错误欢迎指正;
联系方式: happyderekhu@163.com。