【转载】理解Batch Normalization

原文链接:https://zhuanlan.zhihu.com/p/43200897

Batch Normalization(简称BN)自从提出之后,因为效果特别好,很快被作为深度学习的标准工具应用在了各种场合。BN大法虽然好,但是也存在一些局限和问题,诸如当BatchSize太小时效果不佳、对RNN等动态网络无法有效应用BN等。针对BN的问题,最近两年又陆续有基于BN思想的很多改进Normalization模型被提出。BN是深度学习进展中里程碑式的工作之一,无论是希望深入了解深度学习,还是在实践中解决实际问题,BN及一系列改进Normalization工作都是绕不开的重要环节。

一. 从Mini-Batch SGD说起

我们先从Mini-Batch SGD的优化过程讲起,因为这是下一步理解Batch Normalization中Batch所代表具体含义的知识基础。

我们知道,SGD是无论学术圈写文章做实验还是工业界调参跑模型最常用的模型优化算法,但是有时候容易被忽略的一点是:一般提到的SGD是指的Mini-batch SGD,而非原教旨意义下的单实例SGD。

Mini-Batch SGD训练过程(假设Batch Size=2)

所谓“Mini-Batch”,是指的从训练数据全集T中随机选择的一个训练数据子集合。假设训练数据集合T包含N个样本,而每个Mini-Batch的Batch Size为b,于是整个训练数据可被分成N/b个Mini-Batch。在模型通过SGD进行训练时,一般跑完一个Mini-Batch的实例,叫做完成训练的一步(step),跑完N/b步则整个训练数据完成一轮训练,则称为完成一个Epoch。完成一个Epoch训练过程后,对训练数据做随机Shuffle打乱训练数据顺序,重复上述步骤,然后开始下一个Epoch的训练,对模型完整充分的训练由多轮Epoch构成(参考图1)。

在拿到一个Mini-Batch进行参数更新时,首先根据当前Mini-Batch内的b个训练实例以及参数对应的损失函数的偏导数来进行计算,以获得参数更新的梯度方向,然后根据SGD算法进行参数更新,以此来达到本步(Step)更新模型参数并逐步寻优的过程。

图2,Mini-Batch SGD优化过程

具体而言,如果我们假设机器学习任务的损失函数是平方损失函数:

那么,由Mini-Batch内训练实例可得出SGD优化所需的梯度方向为:

其中,是Mini-Batch内第i个训练实例对应的输入和值。是希望学习到的映射函数,其中是函数对应的当前参数值。代表了Mini-Batch中实例i决定的梯度方向,Batch内所有训练实例共同确定了本次参数更新的梯度方向。

根据梯度方向即可利用标准SGD来更新模型参数:
\theta = \theta - \eta \Delta\theta
其中,\eta是学习率。

由上述过程(参考图2)可以看出,对于Mini-Batch SGD训练方法来说,为了能够参数更新,必须先求出梯度方向,而且了能够求出梯度方向,需要对每个实例得出当前参数下映射函数的预测值h_\theta(x_i),这意味着如果是用神经网络来学习映射函数h(x)的话,Mini-Batch内的每个实例需要走一遍当前的网络,产生当前参数下神经网络的预测值,这点请注意,这是理解后续Batch Normalization的基础。

至于Batch Size的影响,目前可以实验证实的是:batch size 设置得较小训练出来的模型相对大batch size训练出的模型泛化能力更强,在测试集上的表现更好,而太大的batch size往往不太Work,而且泛化能力较差。但是背后是什么原因造成的,目前还未有定论,持不同看法者各持己见。因为这不是文本的重点,所以先略过不表。

二、Normalization到底是什么

Normalization的中文翻译一般叫做“规范化”,是一种对数值的特殊函数变换方法,也就是说假设原始的某个数值是x,套上一个起到规范作用的函数,对规范化之前的数值x进行转换,形成一个规范化后的数值。
\hat{x}=f(x)
所谓规范化,是希望转换后的数值\hat{x}满足一定的特性,至于对数值具体如何变换,跟规范化目标有关,也就是说f()函数的具体形式,不同的规范化目标导致具体方法中函数所采用的形式不同。

图3. 神经元

在介绍深度学习Normalization前,我们先普及下神经元的活动过程。深度学习是由神经网络来体现对输入数据的函数变换的,二神经网络的基础单元就是网络神经元,一个典型的神经元对数据进行处理时包含两个步骤的操作(参考图3):
步骤一:对输入数据进行线性变换,产生净激活值

其中,x是输入,w是权重参数,b是偏置,w和b是需要进过训练学习的网络参数。
步骤二:套上非线性激活函数,神经网络的非线性能力来自于此,目前深度学习最常用的激活函数是Relu函数.

如此一个神经元就完成了对输入数据的非线性函数变换。这里需要强调下,步骤一的输出一般称为净激活(Net Activation),第二步骤经过激活函数后得到的值为激活值。为了描述简洁,本文后续文字中使用激活的地方,其实指的是未经激活函数的净激活值,而非一般意义上的激活,这点还请注意。

至于深度学习中的Normalization,因为神经网络里主要有两类实体:神经元或者连接神经元的边,所以按照规范化操作涉及对象的不同可以分为两大类,一类是对第L层每个神经元的激活值或者说对于第L+1层网络神经元的输入值进行Normalization操作,比如BatchNorm/LayerNorm/InstanceNorm/GroupNorm等方法都属于这一类;另外一类是对神经网络中连接相邻隐层神经元之间的边上的权重进行规范化操作,比如Weight Norm就属于这一类。广义上讲,一般机器学习里看到的损失函数里面加入的对参数的的L1/L2等正则项,本质上也属于这第二类规范化操作。L1正则的规范化目标是造成参数的稀疏化,就是争取达到让大量参数值取得0值的效果,而L2正则的规范化目标是有效减小原始参数值的大小。有了这些规范目标,通过具体的规范化手段来改变参数值,以达到避免模型过拟合的目的。

本文主要介绍第一类针对神经元的规范化操作方法,这是目前DNN做Normalization最主流的做法。

图4. Normalizaition加入的位置

那么对于第一类的Normalization操作,其在什么位置发挥作用呢?
目前有两种在神经元中插入Normalization操作的地方(参考图4),

  • 第一种是原始BN论文提出的,放在激活函数之前;
  • 另外一种是后续研究提出的,放在激活函数之后,不少研究表明将BN放在激活函数之后效果更好。

对于神经元的激活值来说,不论哪种Normalization方法,其规范化目标都是一样的,就是将其激活值规整为均值为0,方差为1的正态分布。即规范化函数统一都是如下形式:



写成两步的模式是为了方便讲解,如果写成一体的形式,则是如下形式:


其中,a_i 为某个神经元原始激活值,a^{norm}_i为经过规范化操作后的规范后值。整个规范化过程可以分解为两步,第一步参考公式(1),是对激活值规整到均值为0,方差为1的正态分布范围内。其中,\mu
是通过神经元集合S(至于S如何选取读者可以先不用关注,后文有述)中包含的m个神经元各自的激活值求出的均值,即:


为根据均值和集合S中神经元各自激活值求出的激活值标准差:

其中,是诶了增加训练稳定性二加入的晓得常量数据。

第二步参考公式(2),主要目标是让每个神经元在训练过程中学习到对应的两个调节因子,对规范到0均值,1方差的值进行微调。因为经过第一步操作后,Normalization有可能降低神经网络的非线性表达能力,所以会以此方式来补偿Normalization操作后神经网络的表达能力。

目前神经网络中常见的第一类Normalization方法包括Batch Normalization/Layer Normalization/Instance Normalization和Group Normalization,BN最早由Google研究人员于2015年提出,后面几个算法算是BN的改进版本。不论是哪个方法,其基本计算步骤都如上所述,大同小异,最主要的区别在于神经元集合S的范围怎么定,不同的方法采用了不同的神经元集合定义方法。

为什么这些Normalization需要确定一个神经元集合S呢?原因很简单,前面讲过,这类深度学习的规范化目标是将神经元的激活值a_i限定在均值为0,方差为1的正态分布中。而为了能够对网络中某个神经元的激活值规范到均值为0方差为1的范围,必须有一定的手段求出均值和方差,而均值和方差都是个统计指标,要计算这两个指标一定是在一个集合范围内才可行,所以这就要就必须指定一个神经元组成的集合,利用这个集合里每个神经元的激活来统计出所需的均值和方差,这样才能达到预定的规范化目标。

图5. Normalization具体例子

图5给出了这类Normalization的一个计算过程的具体例子,例子中假设网络结构是前向反馈网络,对于隐层的三个节点来说,其原初的激活值为[0.4,-0.6,0.7],为了可以计算均值为0方差为1的正态分布,划定集合S中包含了这个网络中的6个神经元,至于如何划定集合S读者可以先不用关心,此时其对应的激活值如图中所示,根据这6个激活值,可以算出对应的均值和方差。有了均值和方差,可以利用公式3对原初激活值进行变换,如果r和b被设定为1,那么可以得到转换后的激活值[0.21,-0.75,0.50],对于新的激活值经过非线性变换函数比如RELU,则形成这个隐层的输出值[0.21,0,0.50]。这个例子中隐层的三个神经元在某刻进行Normalization计算的时候共用了同一个集合S,在实际的计算中,隐层中的神经元可能共用同一个集合,也可能每个神经元采用不同的神经元集合S,并非一成不变,这点还请留心与注意。

针对神经元的所有Normalization方法都遵循上述计算过程,唯一的不同在于如何划定计算统计量所需的神经元集合S上。读者可以自己思考下,如果你是BN或者其它改进模型的设计者,那么你会如何选取集合S?

三、Batch Normalization如何做?

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,378评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,356评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,702评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,259评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,263评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,036评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,349评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,979评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,469评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,938评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,059评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,703评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,257评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,262评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,501评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,792评论 2 345

推荐阅读更多精彩内容