对BN的理解

BN在网络中的位置和操作流程

引言

机器学习有一个重要假设:IID,就是假设训练数据和测试数据是满足相同分布的,BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。为什么对输入数据做BN,原因在于神经网络学习过程本质上是为了学习数据的分布。

“Internal Covariate Shift”问题:

内部协变量偏移,Internal指的是网络深层的隐层,Covariate(协变量:不可控,但对结果有重要影响)指的是网络的参数。在训练过程中,因为各层参数不停在变化,导致隐层的输入分布老是变来变去。

BN的基本思想:

每个隐层节点的激活输入分布固定下来,避免了“Internal Covariate Shift”问题了,顺带解决反向传播中梯度消失问题。BN思路来源于:如果对图像做白化操作(0均值,1方差的正态分布),神经网络收敛较快,深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,BN可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作。

一句话:对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程。

疑点:BN操作之后,非线性激活函数变成了和线性函数一样的效果,显然是不行的,为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。这样找到一个线性和非线性的较好的平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。这里理想状态的scale和shift操作会不会又把x变换成未变换之前的状态,又回到Internal Covariate Shift问题哪里?应该不会哈哈哈,否则BN完全没用了啊,事实证明。

Inference时的BN操作:

一个实例是没法算实例集合求出的均值和方差,既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可。把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量。设置model.eval()的一个作用就是固定BN层,不像在训练阶段去求每个mini-batch的均值方差,而是直接取出之前记录在网络里面的每个mini-batch的方差,去求期望.

个人理解

为什么bs越大越好,因为bs越大,每个bs的分布就越趋近于同分布,这样网络比较容易学习数据的分布规律,梯度更新方向比较一致,收敛更快。

BN中的参数

看一个例子

import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        return x

model = Net()
for name, para in model.named_parameters():
    print(name, para)

print('************************************************************')
for name, buffer in model.named_buffers():
    print(name, buffer)

输出为

OrderedDict([('conv1.weight', tensor([[[[ 0.0108,  0.1240,  0.0641],
          [ 0.0838,  0.0657,  0.0785],
          [ 0.0755, -0.1763, -0.0934]],

         [[-0.1210, -0.1455, -0.1416],
          [ 0.0903,  0.0632,  0.0489],
          [-0.0614, -0.1614,  0.1625]],

         [[ 0.1661, -0.0992, -0.1398],
          [ 0.1170,  0.1084,  0.1536],
          [ 0.0179,  0.1310, -0.0289]]],


        [[[ 0.1363,  0.1840,  0.1140],
          [ 0.0471,  0.0555,  0.1758],
          [-0.0386,  0.1077,  0.1612]],

         [[ 0.1177,  0.1799, -0.0495],
          [-0.0314, -0.1714,  0.1125],
          [-0.0723, -0.0770,  0.1663]],

         [[-0.1474,  0.0866, -0.0111],
          [ 0.1476, -0.0468, -0.0683],
          [ 0.0535,  0.1440,  0.1900]]],


        [[[-0.0954,  0.0743, -0.0975],
          [ 0.0741,  0.1436, -0.1203],
          [-0.0047,  0.1317, -0.1513]],

         [[-0.1422,  0.1404,  0.1614],
          [ 0.0025, -0.1499,  0.1647],
          [ 0.0192,  0.0324,  0.0593]],

         [[-0.0041,  0.1813, -0.1696],
          [ 0.0822,  0.1765, -0.1627],
          [ 0.0262,  0.1857, -0.0359]]],


        [[[-0.1816, -0.1198, -0.1289],
          [-0.0138,  0.1118, -0.0687],
          [-0.0078, -0.0975, -0.0646]],

         [[ 0.1763, -0.0490, -0.1117],
          [ 0.0976, -0.0156,  0.1104],
          [-0.0755,  0.0067,  0.0637]],

         [[-0.0131, -0.1783,  0.0628],
          [ 0.1020,  0.1713, -0.0764],
          [-0.1752,  0.0589, -0.0661]]],


        [[[-0.0292,  0.1491,  0.1690],
          [-0.1483,  0.1089, -0.1463],
          [-0.1159,  0.0097,  0.1525]],

         [[-0.0439, -0.0683, -0.0691],
          [-0.0465, -0.0289,  0.1653],
          [ 0.1307, -0.0170, -0.1875]],

         [[-0.0941,  0.1616,  0.0168],
          [ 0.1385,  0.1919,  0.0238],
          [-0.0705,  0.1550,  0.1585]]],


        [[[ 0.1091,  0.0602, -0.1886],
          [ 0.0663,  0.1151, -0.1629],
          [ 0.0955, -0.1370, -0.1030]],

         [[-0.1690,  0.1786,  0.0723],
          [-0.0280, -0.0451, -0.0303],
          [-0.0342, -0.0909, -0.1883]],

         [[ 0.1072,  0.1869,  0.0249],
          [ 0.1028, -0.1043,  0.0852],
          [-0.0532, -0.1132, -0.0372]]]])), ('conv1.bias', tensor([-0.0907,  0.1700, -0.0342,  0.1511,  0.0931,  0.0797])), ('bn1.weight', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0))])

conv1.weight Parameter containing:
tensor([[[[-0.1410,  0.0936, -0.0152],
          [-0.1397, -0.1212, -0.1048],
          [-0.1421, -0.0171,  0.0640]],

         [[ 0.1423, -0.1203, -0.0369],
          [-0.0067,  0.0966,  0.1195],
          [ 0.0143,  0.0839, -0.0283]],

         [[-0.1537, -0.1123, -0.1345],
          [ 0.0886,  0.1017,  0.0533],
          [-0.0084, -0.1251,  0.1744]]],


        [[[ 0.1859, -0.1693, -0.1616],
          [ 0.0567,  0.1256,  0.0887],
          [-0.0761, -0.1245, -0.0764]],

         [[ 0.1298, -0.1307, -0.0978],
          [ 0.0780,  0.0860, -0.0598],
          [-0.0295, -0.1884,  0.0191]],

         [[-0.1898, -0.0489,  0.1485],
          [-0.1887, -0.0618, -0.1429],
          [ 0.1066, -0.0593,  0.0559]]],


        [[[ 0.0189,  0.0575,  0.1358],
          [-0.1079, -0.0591, -0.1221],
          [ 0.0100, -0.0392,  0.0423]],

         [[ 0.1072,  0.1461, -0.1267],
          [-0.1478,  0.1647,  0.1149],
          [ 0.0258, -0.1862, -0.0070]],

         [[ 0.1138, -0.0968,  0.0016],
          [-0.0955,  0.1802,  0.0822],
          [-0.1311,  0.0945, -0.0038]]],


        [[[ 0.1647, -0.0404,  0.0610],
          [-0.1558,  0.1357,  0.1779],
          [-0.0070,  0.1030, -0.0585]],

         [[ 0.1592,  0.0970,  0.0614],
          [-0.0068, -0.0732,  0.1352],
          [ 0.0447,  0.0769, -0.0384]],

         [[-0.0589, -0.0711, -0.0543],
          [ 0.0926, -0.0984, -0.0573],
          [ 0.0687,  0.1849,  0.0993]]],


        [[[ 0.0730,  0.0036,  0.0584],
          [ 0.0568,  0.0311, -0.1742],
          [ 0.1582, -0.0496, -0.0620]],

         [[ 0.0348, -0.1415,  0.0212],
          [-0.1688,  0.0436, -0.1485],
          [ 0.0154, -0.1302,  0.1255]],

         [[ 0.1393,  0.0575, -0.1821],
          [ 0.0244, -0.1584,  0.0886],
          [-0.0158, -0.1907, -0.1038]]],


        [[[ 0.0019, -0.0077, -0.0073],
          [ 0.0667,  0.1904,  0.1622],
          [-0.1315,  0.1265,  0.0110]],

         [[ 0.0979,  0.0211, -0.1126],
          [ 0.1260,  0.1614,  0.0309],
          [-0.0724, -0.1381,  0.1275]],

         [[-0.0206, -0.0674, -0.0358],
          [-0.0800, -0.0408,  0.1636],
          [ 0.0082, -0.0014, -0.0292]]]], requires_grad=True)
conv1.bias Parameter containing:
tensor([-0.0660,  0.0184,  0.0102,  0.1804, -0.0702,  0.0977],
       requires_grad=True)
bn1.weight Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
bn1.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
************************************************************
bn1.running_mean tensor([0., 0., 0., 0., 0., 0.])
bn1.running_var tensor([1., 1., 1., 1., 1., 1.])
bn1.num_batches_tracked tensor(0)

可以看到,网络中的参数除了parameters,还有一些不用更新的参数,主要是BN中的'bn1.running_mean'bn1.running_var,这些参数只在forward时进行统计计算,backward时并不会被更新,这些参数也称为buffer,可以用model.buffers()获取。顺便提一下,在进行推理时设置model.val(),会固定这些参数,不会计算,而是采用记录的全局统计量,如上所述。

创建于2020.11.26

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,311评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,339评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,671评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,252评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,253评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,031评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,340评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,973评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,466评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,937评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,039评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,701评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,254评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,259评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,497评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,786评论 2 345

推荐阅读更多精彩内容