前面在讲解卷积神经网络时介绍了VGG模型,随着VGG模型的成功,更深、更宽、更复杂的网络似乎成为卷积神经网络模型的主流。但同时也带来了问题。
神经网络退化
卷积神经网络能够用来侦测对象的低、中、高特征,网络的层数越多,就越能够提取到更多不同的特征越。同时,通过还原镜像还发现,越深的网络提取的特征越抽象,越具有语义信息。
这就产生了一个疑问,是不是可以单纯地听过增加网络模型的深度和宽度,即增加更多的隐藏层和每个层中的神经元去得到更好的结果呢?
答案是否定的。实验发现,随着卷积神经网络深度的增加,在训练集上,准确率很难达到100%正确,甚至出现了下降。
这似乎不能简单地解释为卷积神经网络的性能下降,因为卷积神经网络加深的基础理论就是越深越好。如果强行解释为过拟合,似乎也不能解释准确率下降的问题,因为如果产生过拟合,那么在训练集上卷积神经网络应该表现得更好才对。
这个问题被称为“神经网络退化”。神经网络退化问题的产生说明了卷积神经网络不能够被简单地使用堆积层数的方式进行优化。
2015年,随着152层深的ResNet问世,使得训练深度达到数百甚至数千层的网络成为可能,而且性能依然优异。
本章主要介绍ResNet及其变种。
ResNet基础原理
ResNet的出现改变了VGG系列所带来的思维定势,创造性地采用模块化的思维替代整体的卷积层,通过一个个魔魁啊的堆叠来替代不断增加的卷积层。对ResNet的研究和不断改进成了过去几年中计算机视觉和深度学习领域最具有突破性的工作。并且由于其表征能力强,ResNet在图像分类任务以外许多计算机诗句应用上也取得了巨大的性能提升,例如目标检测和人脸识别。
ResNet诞生
卷积神经网络的实质就是无限拟合一个符合对应目标的函数。而根据泛逼近定理(Universal Approximation Theorem),如果给定足够的容量,一个单层的前馈网络就可以表示任何函数。但是,这个层可能非常大,而且网络容易过拟合数据。因此,学术界有一个共同的认识,就是网络架构需要更深。
但是,研究发现只是简单地将层堆叠在一起,增加网络的深度并不会起太大的作用。这是由于梯度消失(Vanishing Gradient)导致深层的网络很难训练。因为梯度反向传播到前一层,重复相乘可能是梯度无穷小,结果随着网络层数越深,其性能越趋于饱和,甚至开始迅速下降。
在ResNet之前,已经出现好几种处理梯度消失问题的方法,但是没有一个方法能够真正解决这个问题。何恺明等人与2015年发表的论文《用于图像识别的深度残差学习》(Deep Residual Learning for Image Recognition)中认为,堆叠的层不应该降低网络的性能,可以简单地在当前网络上堆叠映射层(不处理任务的层),并且所得到的架构性能不变。
当f(x)为0时,f’(x)等于x;当f(x)不为0时,所获得的f’(x)性能要优于单纯地输入x。公式表明,叫声的模型所产生的训练误差不应该比较浅的模型误差更高。假设让堆叠的层拟合一个残差映射(Residual Mapping),要比让它们直接拟合所需的底层映射更容易。
从下图可以看到,残差映射与传统直接相连的卷积神经网络相比,最大的变化就是加入了一个恒等映射层,即y = x层,其主要作用是使得网络随着深度的增加而不会产生权重衰减、梯度衰减或者消失这些问题。
上图中,F(x)表示残差,F(x) + x是最终的映射输出,因此可以得到网络的最终输出为H(x) = F(x) + x。由于网络框架中有2个卷积层和2个ReLU函数,因最终的输出结果可以表示为,
其中H_{1}是第一层的输出,而H_{2}是第二层输出。这样在输入与输出有相同维度时,可以使用直接输入的形式传递到框架的输出层。
ResNet整体结构图与VGG比较如下图所示,
上图是19层VGG、34层普通结构的神经网络以及24层ResNet神经网络的对比图。通过验证得知,在使用了ResNet的架构后发现,层数不断增加导致的训练集上误差增加的现象被消除了,ResNet网络的训练误差会随着层数增加而逐渐减小,并且在测试集上的表现也会更好。下图分别是18层和34层普通和残差网络训练后误差的下降曲线。
除了用于讲解的二层残差学习单元,实际上更多的是使用[1, 1]结构的三层残差学习单元,
这是借鉴了NiN(Network in Network)模型的思想,在二层残差单元中包含1个[3, 3]卷积层的基础上,更包含了2个[1, 1]大小的卷积,放在[3, 3]卷积层的前后,执行先降维再升维的操作。
无论采用哪种连接方式,ResNet的核心是引入一个“身份捷径连接”(Identify Shortcut Connection),直接跳过一层或多层将输入层与输出层进行连接。实际上,ResNet并不是第一个利用Identify Shortcut Connection的方法,较早期就有相关研究人员在卷积神经网络中引入了“门控短路电路”,即参数化的门控系统允许各种信息通过网络通道。如下图,
但并不是所有加入了“Shortcut”的卷积神经网络都会提高传输效果。在后续的研究中,有不少研究人员对残差块进行了改进,但是和遗憾,它并不能获得性能上的提高。
目前图(a) original性能最好。
ResNet架构组件
可能现在已经迫不及待想要自定义自己的残差网络了。在构建自己的残差网络之前,需要被准备好相关的程序设计工具。上一章介绍了jax.examle_libraries包,为了加深印象,ResNet最终的代码实战会使用这些JAX提供的组件,这些组件就是要使用的工具,即已经设计好结构并可以直接使用的代码。
- jax.exmaple_libraries.stax.Conv,卷积核。从模型上看,需要更改的内容很少,即卷积核的大小、输出通道数以及所定义的卷积层的名称。
- jax.example_libraries.stax.BatchNorm,对数据进行批标准化,这是使用批标准化对数据进行处理。
- jax.example_libraries.MaxPool,最大池化层。
- jax.example_libraries.AvgPool,平均池化层。
这些是在ResNet模型单元中所用到的基本工具,有了这些工具,就可以直接构建ResNet模型单元。下面将对实现函数进行详细介绍。
jax.exmaple_libraries.stax.Conv卷积层
jax.exmaple_libraries.stax.Conv是卷积计算的实现,其函数原型如下,
import functools
def GeneralConvolution(dimension_numbers, out_channel, filter_shape, strides = None, padding = "VALID", weights_init = None, biases_init = normal(1e-6)):
# ...
return init_fun, apply_fun
需要说明的是,卷积Convolution的实现在JAX中是挺贵2个步骤完成,第一步定义卷积主函数,也就是普通函数的卷积,自后对主函数进行格式化包装,生成符合计算需求的函数计算主体部分。
下面用一个例子说明jax.exmaple_libraries.stax.Conv的使用方法,
import jax
def model():
filter_number = 64
filter_size = (3, 3)
strides = (2, 2)
jax.example_libraries.stax.Conv(filter_number, filter_size, strides)
jax.example_libraries.stax.Conv(filter_number, filter_size, strides, padding = "SAME")
def train():
model()
def main():
train()
jax.example_libraries.stax.BatchNorm批标准化
BatchNormalization是目前最常见的数据标准化方法,也是批量标准化方法。输入数据经过处理之后能够显著地加速训练速度,并且减少过拟合出现的可能性。其函数原型如下所示,
def BatchNormalization(axis = (0, 1, 2), epsilon = 1e-5, center = True, sca>
...
return init_fun, apply_fun
BatchNormalization在jax.example_libraries.stax.BatchNorm调用时比较简单,直接初始化,一般不用传递参数。
jax.example_libraries.stax.Dense全连接层
Dense是全连接层,其在使用时需要输入分类的类别数,如下所示,
def Dense(out_dimensions, weights_init = glorot_normal(), biases_init = nromal(>
...
return init_fun, apply_fun
其中out_dimensions需要在类被初始化的时候定义,如下所示,
def Dense(out_dimensions = 10, weights_init = glorot_normal(), biases_init = nromal(>
...
return init_fun, apply_fun
Pooling池化层
Pooling即池化层。stax模块包含了多个池化方法,这几个池化方法都是类似的。包含jax.example_libraries.stax.MaxPool、jax.example_libraries.stax.SumPool、jax.example_libraries.staxAvgPool,分别代表最大、求和和平均池化方法。 下面以常用的jax.example_libraries.stax.AvgPool为例进行解释,
def AvgPool(window_shape, strides = None, padding = "VALID", spec = None):
...
return init_fun, apply_fun
可以看到,该方法需要输入3个参数,分别是池化窗口大小windw_shape、池化步进步长strides以及填充方式padding。
import jax
def model():
filter_number = 64
filter_size = (3, 3)
strides = (2, 2)
jax.example_libraries.stax.Conv(filter_number, filter_size, strides)
jax.example_libraries.stax.Conv(filter_number, filter_size, strides, padding = "SAME")
window_shape = (3, 3)
strides = (2, 2)
jax.example_libraries.stax.AvgPool(window_shape = window_shape, strides = strides)
def train():
model()
def main():
train()
除了上面这写谈及的类,还有其他构成神经网络的类,感兴趣可以自行尝试。
另外,在此说明,不同版本的JAX,其命名空间可能不同,比如早前版本jax.experimental.stax,截止到目前是jax.example_libraries.stax。
jax.example_libraries.stax特殊的类
下面介绍jax.example_libraries.stax一些特殊类。
jax.example_libraries.stax.FanOut数据复制
def FanOut(number):
init_fun = lambda prng, input_shape:([input_shape] * number, ())
apply_fun = lambda params, inputs, **kwargs: [inputs] * number
return init_fun, apply_fun
def model():
FanOut(number = 2)
从类构造函数函数原型可以看出,这个类是对输入的数据进行复制,接受一个参数number,代表复制的份数。
jax.example_libraries.stax.FanInSum数据求和
def FanInSum():
init_fun = lambda prng, input_shape: (input_shape[0], ())
apply_fun = lambda params, inputs, **kwargs: sum(inputs)
return init_fun, apply_fun
def model():
fanInsum = FanInSum()
从类构造函数函数原型可以看出,这个类是对输入的数据进行求和运算。
jax.example_libraries.stax.FanInConcat数据串接
import jax
def FanInConcat(axis = -1):
def init_fun(prng, input_shape):
ax = axis &% len(input_shape[0])
concat_size = sum(shape[ax] for shape in input_shape)
output_shape = input_shape[0][: ax] + (concat_size, ) + input_shape[0][>
return output_shape, ()
def apply_fun(params, inputs, **kwargs):
return jax.numpy.concatenate(inputs, axis)
return init_fun, apply_fun
从类构造函数函数原型可以看出,这个类是对输入的数据在最后一维进行串接(Concatenate),从而形成一个新的数据。
jax.example_libraries.stax.Identifty
def Identity():
init_fun = lambda prng, inputt_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: inputs
return init_fun, apply_fun
本类的作用时对输入的数据进行完整的输出。
结论
本章简单介绍了ResNet的背景及原理,以及相对于普通卷积神经网络ResNet的改进。另外也承接上一章,详细介绍了JAX本身提供的jax.example_libraries.stax命名空间下深度学习基本组件,以及jax.example_libraries.stax下特殊类。
本章是理论准备阶段,为最终的代码实战做好理论准备。