第49章 ResNet模型及其架构组件

前面在讲解卷积神经网络时介绍了VGG模型,随着VGG模型的成功,更深、更宽、更复杂的网络似乎成为卷积神经网络模型的主流。但同时也带来了问题。

神经网络退化

卷积神经网络能够用来侦测对象的低、中、高特征,网络的层数越多,就越能够提取到更多不同的特征越。同时,通过还原镜像还发现,越深的网络提取的特征越抽象,越具有语义信息。

这就产生了一个疑问,是不是可以单纯地听过增加网络模型的深度和宽度,即增加更多的隐藏层和每个层中的神经元去得到更好的结果呢?

答案是否定的。实验发现,随着卷积神经网络深度的增加,在训练集上,准确率很难达到100%正确,甚至出现了下降。

这似乎不能简单地解释为卷积神经网络的性能下降,因为卷积神经网络加深的基础理论就是越深越好。如果强行解释为过拟合,似乎也不能解释准确率下降的问题,因为如果产生过拟合,那么在训练集上卷积神经网络应该表现得更好才对。

这个问题被称为“神经网络退化”。神经网络退化问题的产生说明了卷积神经网络不能够被简单地使用堆积层数的方式进行优化。

2015年,随着152层深的ResNet问世,使得训练深度达到数百甚至数千层的网络成为可能,而且性能依然优异。

本章主要介绍ResNet及其变种。

ResNet基础原理

ResNet的出现改变了VGG系列所带来的思维定势,创造性地采用模块化的思维替代整体的卷积层,通过一个个魔魁啊的堆叠来替代不断增加的卷积层。对ResNet的研究和不断改进成了过去几年中计算机视觉和深度学习领域最具有突破性的工作。并且由于其表征能力强,ResNet在图像分类任务以外许多计算机诗句应用上也取得了巨大的性能提升,例如目标检测和人脸识别。

ResNet诞生

卷积神经网络的实质就是无限拟合一个符合对应目标的函数。而根据泛逼近定理(Universal Approximation Theorem),如果给定足够的容量,一个单层的前馈网络就可以表示任何函数。但是,这个层可能非常大,而且网络容易过拟合数据。因此,学术界有一个共同的认识,就是网络架构需要更深。

但是,研究发现只是简单地将层堆叠在一起,增加网络的深度并不会起太大的作用。这是由于梯度消失(Vanishing Gradient)导致深层的网络很难训练。因为梯度反向传播到前一层,重复相乘可能是梯度无穷小,结果随着网络层数越深,其性能越趋于饱和,甚至开始迅速下降。

图1 网络层数越深,其性能越趋于饱和

在ResNet之前,已经出现好几种处理梯度消失问题的方法,但是没有一个方法能够真正解决这个问题。何恺明等人与2015年发表的论文《用于图像识别的深度残差学习》(Deep Residual Learning for Image Recognition)中认为,堆叠的层不应该降低网络的性能,可以简单地在当前网络上堆叠映射层(不处理任务的层),并且所得到的架构性能不变。
f'\left( x \right) = \left\{ \begin{array}{cl} x & : \ f\left( x \right) = 0 \\ f\left( x \right) + x & : \ f\left( x \right) \neq 0 \end{array} \right.
当f(x)为0时,f’(x)等于x;当f(x)不为0时,所获得的f’(x)性能要优于单纯地输入x。公式表明,叫声的模型所产生的训练误差不应该比较浅的模型误差更高。假设让堆叠的层拟合一个残差映射(Residual Mapping),要比让它们直接拟合所需的底层映射更容易。

从下图可以看到,残差映射与传统直接相连的卷积神经网络相比,最大的变化就是加入了一个恒等映射层,即y = x层,其主要作用是使得网络随着深度的增加而不会产生权重衰减、梯度衰减或者消失这些问题。

图2 残差映射的恒等映射层

上图中,F(x)表示残差,F(x) + x是最终的映射输出,因此可以得到网络的最终输出为H(x) = F(x) + x。由于网络框架中有2个卷积层和2个ReLU函数,因最终的输出结果可以表示为,
H_{1}\left( x \right) = relu_{1}\left( w_{1}\times x \right) H_{2}\left( x \right) = relu_{2}\left( w_{2}\times H_{1}\left( x \right) \right) H\left( x \right) = H_{2} + x

图3 残差网络映射输出计算方式

其中H_{1}是第一层的输出,而H_{2}是第二层输出。这样在输入与输出有相同维度时,可以使用直接输入的形式传递到框架的输出层。

ResNet整体结构图与VGG比较如下图所示,

图4 ResNet整体结构图与VGG比较

上图是19层VGG、34层普通结构的神经网络以及24层ResNet神经网络的对比图。通过验证得知,在使用了ResNet的架构后发现,层数不断增加导致的训练集上误差增加的现象被消除了,ResNet网络的训练误差会随着层数增加而逐渐减小,并且在测试集上的表现也会更好。下图分别是18层和34层普通和残差网络训练后误差的下降曲线。

图5 18层和34层普通和残差网络训练后误差的下降曲线

除了用于讲解的二层残差学习单元,实际上更多的是使用[1, 1]结构的三层残差学习单元,

图6 二层残差学习单元和三层残差学习单元

这是借鉴了NiN(Network in Network)模型的思想,在二层残差单元中包含1个[3, 3]卷积层的基础上,更包含了2个[1, 1]大小的卷积,放在[3, 3]卷积层的前后,执行先降维再升维的操作。

无论采用哪种连接方式,ResNet的核心是引入一个“身份捷径连接”(Identify Shortcut Connection),直接跳过一层或多层将输入层与输出层进行连接。实际上,ResNet并不是第一个利用Identify Shortcut Connection的方法,较早期就有相关研究人员在卷积神经网络中引入了“门控短路电路”,即参数化的门控系统允许各种信息通过网络通道。如下图,

图7 Identify Shortcut Connection

但并不是所有加入了“Shortcut”的卷积神经网络都会提高传输效果。在后续的研究中,有不少研究人员对残差块进行了改进,但是和遗憾,它并不能获得性能上的提高。

目前图(a) original性能最好。

ResNet架构组件

可能现在已经迫不及待想要自定义自己的残差网络了。在构建自己的残差网络之前,需要被准备好相关的程序设计工具。上一章介绍了jax.examle_libraries包,为了加深印象,ResNet最终的代码实战会使用这些JAX提供的组件,这些组件就是要使用的工具,即已经设计好结构并可以直接使用的代码。

  • jax.exmaple_libraries.stax.Conv,卷积核。从模型上看,需要更改的内容很少,即卷积核的大小、输出通道数以及所定义的卷积层的名称。
  • jax.example_libraries.stax.BatchNorm,对数据进行批标准化,这是使用批标准化对数据进行处理。
  • jax.example_libraries.MaxPool,最大池化层。
  • jax.example_libraries.AvgPool,平均池化层。

这些是在ResNet模型单元中所用到的基本工具,有了这些工具,就可以直接构建ResNet模型单元。下面将对实现函数进行详细介绍。

jax.exmaple_libraries.stax.Conv卷积层

jax.exmaple_libraries.stax.Conv是卷积计算的实现,其函数原型如下,

import functools

def GeneralConvolution(dimension_numbers, out_channel, filter_shape, strides = None, padding = "VALID", weights_init = None, biases_init = normal(1e-6)):
    
    # ...
    return init_fun, apply_fun

需要说明的是,卷积Convolution的实现在JAX中是挺贵2个步骤完成,第一步定义卷积主函数,也就是普通函数的卷积,自后对主函数进行格式化包装,生成符合计算需求的函数计算主体部分。

下面用一个例子说明jax.exmaple_libraries.stax.Conv的使用方法,


import jax

def model():
    
    filter_number = 64
    filter_size = (3, 3)
    strides = (2, 2)
    
    jax.example_libraries.stax.Conv(filter_number, filter_size, strides)
    jax.example_libraries.stax.Conv(filter_number, filter_size, strides, padding = "SAME")
    
def train():
    
    model()
    
def main():
    
    train()

jax.example_libraries.stax.BatchNorm批标准化

BatchNormalization是目前最常见的数据标准化方法,也是批量标准化方法。输入数据经过处理之后能够显著地加速训练速度,并且减少过拟合出现的可能性。其函数原型如下所示,

def BatchNormalization(axis = (0, 1, 2), epsilon = 1e-5, center = True, sca>
    
        ...
        
        return init_fun, apply_fun

BatchNormalization在jax.example_libraries.stax.BatchNorm调用时比较简单,直接初始化,一般不用传递参数。

jax.example_libraries.stax.Dense全连接层

Dense是全连接层,其在使用时需要输入分类的类别数,如下所示,

def Dense(out_dimensions, weights_init = glorot_normal(), biases_init = nromal(>
    
    ...
    
    return init_fun, apply_fun

其中out_dimensions需要在类被初始化的时候定义,如下所示,

def Dense(out_dimensions = 10, weights_init = glorot_normal(), biases_init = nromal(>
    
    ...
    
    return init_fun, apply_fun
Pooling池化层

Pooling即池化层。stax模块包含了多个池化方法,这几个池化方法都是类似的。包含jax.example_libraries.stax.MaxPool、jax.example_libraries.stax.SumPool、jax.example_libraries.staxAvgPool,分别代表最大、求和和平均池化方法。 下面以常用的jax.example_libraries.stax.AvgPool为例进行解释,

def AvgPool(window_shape, strides = None, padding = "VALID", spec = None):
    
    ...
    
    return init_fun, apply_fun

可以看到,该方法需要输入3个参数,分别是池化窗口大小windw_shape、池化步进步长strides以及填充方式padding。

import jax

def model():
    
    filter_number = 64
    filter_size = (3, 3)
    strides = (2, 2)
    
    jax.example_libraries.stax.Conv(filter_number, filter_size, strides)
    jax.example_libraries.stax.Conv(filter_number, filter_size, strides, padding = "SAME")
    
    window_shape = (3, 3)
    strides = (2, 2)
    
    jax.example_libraries.stax.AvgPool(window_shape = window_shape, strides = strides)
    
def train():
    
    model()
    
def main():
    
    train()

除了上面这写谈及的类,还有其他构成神经网络的类,感兴趣可以自行尝试。

另外,在此说明,不同版本的JAX,其命名空间可能不同,比如早前版本jax.experimental.stax,截止到目前是jax.example_libraries.stax。

jax.example_libraries.stax特殊的类

下面介绍jax.example_libraries.stax一些特殊类。

jax.example_libraries.stax.FanOut数据复制

def FanOut(number):
    
    init_fun = lambda prng, input_shape:([input_shape] * number, ())
    apply_fun = lambda params, inputs, **kwargs: [inputs] * number
    
    return init_fun, apply_fun

def model():
    
    FanOut(number = 2)

从类构造函数函数原型可以看出,这个类是对输入的数据进行复制,接受一个参数number,代表复制的份数。

jax.example_libraries.stax.FanInSum数据求和

def FanInSum():
    
    init_fun = lambda prng, input_shape: (input_shape[0], ())
    apply_fun = lambda  params, inputs, **kwargs: sum(inputs)
    
    return init_fun, apply_fun

def model():
    
    fanInsum = FanInSum()

从类构造函数函数原型可以看出,这个类是对输入的数据进行求和运算。

jax.example_libraries.stax.FanInConcat数据串接

import jax

def FanInConcat(axis = -1):
    
    def init_fun(prng, input_shape):
        
        ax = axis &% len(input_shape[0])
        concat_size = sum(shape[ax] for shape in input_shape)
        
        output_shape = input_shape[0][: ax] + (concat_size, ) + input_shape[0][>
        
        return output_shape, ()
    
    def apply_fun(params, inputs, **kwargs):
        
        return jax.numpy.concatenate(inputs, axis)
    
    return init_fun, apply_fun

从类构造函数函数原型可以看出,这个类是对输入的数据在最后一维进行串接(Concatenate),从而形成一个新的数据。

jax.example_libraries.stax.Identifty
def Identity():
    
    init_fun = lambda prng, inputt_shape: (input_shape, ())
    apply_fun = lambda params, inputs, **kwargs: inputs
    
    return init_fun, apply_fun

本类的作用时对输入的数据进行完整的输出。

结论

本章简单介绍了ResNet的背景及原理,以及相对于普通卷积神经网络ResNet的改进。另外也承接上一章,详细介绍了JAX本身提供的jax.example_libraries.stax命名空间下深度学习基本组件,以及jax.example_libraries.stax下特殊类。

本章是理论准备阶段,为最终的代码实战做好理论准备。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,098评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,213评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,960评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,519评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,512评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,533评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,914评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,804评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,563评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,644评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,350评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,933评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,908评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,146评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,847评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,361评论 2 342

推荐阅读更多精彩内容