Batch Normalization#

Note

Like regularization and Dropout, Batch Normalization (BN) is also a technique for combating overfitting in deep neural networks.
Batch Normalization can accelerate convergence, letting us train deeper networks.

Intuition#

Recall how, when we predicted house prices with linear regression in the first section, the first step was to standardize the input features. Neural networks likewise need standardized inputs; in fact, not only the input layer but the intermediate layers benefit from normalization as well.

Moreover, deeper networks are highly complex and prone to overfitting, which creates a need for regularization.

Batch Normalization is simply standardization performed per minibatch, followed by a learned scale and shift (standardization with the flexibility of a scale and an offset):

\[\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma}\odot\frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}}}{\hat{\boldsymbol{\sigma}}_{\mathcal{B}}} + \boldsymbol{\beta}\]

where \(\hat{\boldsymbol{\mu}}_{\mathcal{B}}\) and \(\hat{\boldsymbol{\sigma}}_{\mathcal{B}}\) are the mean and standard deviation of the minibatch \(\mathcal{B}\); the scale parameter \(\boldsymbol{\gamma}\) and the shift parameter \(\boldsymbol{\beta}\) have the same shape as \(\mathbf{x}\) and are learned along with the other model parameters.

The mean and standard deviation are computed as follows:

\[\hat{\boldsymbol{\mu}}_{\mathcal{B}} = \frac{1}{|\mathcal{B}|}\sum_{\mathbf{x}\in{\mathcal{B}}}\mathbf{x}\]
\[\hat{\boldsymbol{\sigma}}_{\mathcal{B}}^{2} = \frac{1}{|\mathcal{B}|}\sum_{\mathbf{x}\in{\mathcal{B}}}(\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^{2} + \boldsymbol{\epsilon}\]

Here \(\boldsymbol{\epsilon} > 0\) is a small constant added to the variance estimate so that we never divide by zero.
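To make the formulas concrete, here is a minimal training-mode sketch for a 2-D minibatch (rows are examples, columns are features); the function batch_norm and its eps default are illustrative, not a library API:

import torch

def batch_norm(x, gamma, beta, eps=1e-5):
    # x: (batch, features); statistics are computed over the batch dimension
    mu = x.mean(dim=0)                        # per-feature minibatch mean
    var = x.var(dim=0, unbiased=False) + eps  # biased variance plus eps, as in the formula
    x_hat = (x - mu) / torch.sqrt(var)        # standardize
    return gamma * x_hat + beta               # learned scale and shift

x = torch.randn(4, 3)
gamma, beta = torch.ones(3), torch.zeros(3)
y = batch_norm(x, gamma, beta)
print(y.mean(dim=0))                 # ~0 per feature
print(y.std(dim=0, unbiased=False))  # ~1 per feature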

Intuition 2#

In a deep network, the layers near the top (output) have larger gradients and update quickly, while the layers near the bottom (input) suffer from vanishing gradients and update slowly. During training, the top layers therefore tend to converge quickly while the bottom layers converge slowly.

This creates a problem: the top layers depend on the bottom layers, so whenever the bottom parameters change, the top layers have to be retrained even if they had already converged; in turn, updates to the top layers propagate back down during backpropagation. This back-and-forth makes training inefficient.

Batch Normalization addresses this by fixing the mean and variance within each minibatch:

  • For fully connected layers, it acts on the feature dimension

  • For convolutional layers, it acts on the channel dimension (which plays the role of the feature dimension, as in a 1×1 convolution); see the sketch below

Because these statistics come from a random minibatch, they inject a bit of noise into every minibatch, which helps control model complexity.
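The difference in reduction axes is easy to see in code; the shapes below are made up for illustration:

import torch

x_fc = torch.randn(8, 100)          # fully connected input: (batch, features)
x_conv = torch.randn(8, 3, 28, 28)  # convolutional input: (batch, channels, H, W)

# fully connected: one statistic per feature, averaged over the batch dimension
print(x_fc.mean(dim=0).shape)            # torch.Size([100])
# convolutional: one statistic per channel, averaged over batch and spatial dimensions
print(x_conv.mean(dim=(0, 2, 3)).shape)  # torch.Size([3])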

Batch Normalization in PyTorch#

import torch
from torch import nn

net = nn.Sequential(
    # BN usually goes after a fully connected layer and before the activation;
    # the number of features must be specified.
    # CNNs use the corresponding BN layer, BatchNorm2d, shown below.
    nn.Linear(784, 100), nn.BatchNorm1d(100), nn.ReLU(),
    nn.Linear(100, 64), nn.BatchNorm1d(64), nn.ReLU(),
    nn.Linear(64, 10))
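A quick forward pass confirms the shapes line up (the dummy input below is an assumption matching the 784-feature first layer):

X = torch.randn(8, 784)  # a made-up minibatch of flattened 28x28 inputs
print(net(X).shape)      # torch.Size([8, 10])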

Warning

Using Batch Normalization together with Dropout tends to hurt (1 + 1 < 1), so do not combine them; pick one or the other.
Batch Normalization also has a regularizing effect, so it is generally not combined with explicit regularization either.

bn = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 4, 5)
x
tensor([[[[-0.7056, -0.4590,  0.6550, -0.8764,  0.1595],
          [-0.3366,  1.2398,  0.7763, -0.0936, -0.0803],
          [-0.6050, -1.6816, -0.3381,  1.2039,  0.1787],
          [ 1.2616,  0.1024, -0.3560,  1.8771, -0.3419]],

         [[-0.0999, -0.5607,  1.3607, -1.6053,  0.8459],
          [-0.4745, -0.2351,  0.2903, -0.2562, -1.3345],
          [ 1.0218, -2.7007, -0.1659, -1.6627, -1.1253],
          [-0.4410,  0.7307, -0.4006, -0.8461, -1.1328]],

         [[ 0.3236, -0.2399, -1.3959,  1.1676, -1.8621],
          [-0.1137,  0.5725,  0.5061,  0.2277,  0.3814],
          [-1.3717,  1.3927, -0.4582,  0.6749, -1.3643],
          [ 0.2869,  1.1865, -0.5557, -1.3239,  0.0709]]],


        [[[-1.0764,  0.4783,  0.5798,  0.1085,  1.2515],
          [ 1.1440, -1.0778, -0.9636,  0.2656,  0.3380],
          [-0.2073,  1.3440, -0.1970,  1.6990, -1.2503],
          [-1.0565, -0.5901,  0.8860, -0.0596, -1.2087]],

         [[-1.8812, -1.3341, -0.9922,  0.5816,  0.1440],
          [ 1.0714, -0.7039, -0.1289, -2.0139,  0.1285],
          [ 0.5642,  0.4715,  0.5909, -1.2737, -0.1398],
          [ 1.1680,  0.1156,  0.0612, -0.5378, -0.1655]],

         [[-0.7684,  0.6300, -0.5697, -0.8114,  1.2653],
          [-1.3524, -0.3176,  1.0327, -1.2614, -0.0778],
          [-0.3558,  0.0546,  0.3996, -0.1144, -0.2044],
          [-0.4393,  0.3843,  0.2791,  2.0548, -0.8331]]]])
bn(x)
tensor([[[[-0.8553, -0.5760,  0.6854, -1.0486,  0.1244],
          [-0.4374,  1.3476,  0.8227, -0.1623, -0.1472],
          [-0.7413, -1.9604, -0.4391,  1.3069,  0.1461],
          [ 1.3723,  0.0597, -0.4594,  2.0693, -0.4435]],

         [[ 0.2413, -0.2491,  1.7957, -1.3607,  1.2478],
          [-0.1574,  0.0974,  0.6566,  0.0750, -1.0726],
          [ 1.4350, -2.5264,  0.1710, -1.4218, -0.8499],
          [-0.1217,  1.1252, -0.0787, -0.5528, -0.8579]],

         [[ 0.4480, -0.1893, -1.4967,  1.4025, -2.0239],
          [-0.0466,  0.7294,  0.6543,  0.3396,  0.5134],
          [-1.4693,  1.6570, -0.4362,  0.8453, -1.4609],
          [ 0.4065,  1.4239, -0.5465, -1.4152,  0.1622]]],


        [[[-1.2751,  0.4853,  0.6003,  0.0665,  1.3609],
          [ 1.2391, -1.2767, -1.1474,  0.2444,  0.3265],
          [-0.2910,  1.4656, -0.2793,  1.8676, -1.4720],
          [-1.2526, -0.7245,  0.9470, -0.1237, -1.4249]],

         [[-1.6543, -1.0721, -0.7083,  0.9666,  0.5009],
          [ 1.4878, -0.4014,  0.2104, -1.7956,  0.4844],
          [ 0.9480,  0.8493,  0.9765, -1.0078,  0.1988],
          [ 1.5906,  0.4707,  0.4127, -0.2247,  0.1714]],

         [[-0.7870,  0.7944, -0.5624, -0.8357,  1.5129],
          [-1.4475, -0.2772,  1.2499, -1.3446, -0.0060],
          [-0.3204,  0.1438,  0.5339, -0.0474, -0.1491],
          [-0.4148,  0.5167,  0.3977,  2.4058, -0.8602]]]],
       grad_fn=<NativeBatchNormBackward0>)
# one mean and variance per channel (BatchNorm2d also adds eps=1e-5 to the
# variance, hence the occasional difference in the last printed digit)
(x - x.mean(dim=[0, 2, 3], keepdim=True)) / x.std(dim=[0, 2, 3], unbiased=False, keepdim=True)
tensor([[[[-0.8553, -0.5760,  0.6854, -1.0486,  0.1244],
          [-0.4374,  1.3476,  0.8227, -0.1623, -0.1472],
          [-0.7413, -1.9604, -0.4391,  1.3069,  0.1461],
          [ 1.3723,  0.0597, -0.4594,  2.0693, -0.4435]],

         [[ 0.2413, -0.2491,  1.7957, -1.3607,  1.2478],
          [-0.1574,  0.0974,  0.6566,  0.0750, -1.0726],
          [ 1.4351, -2.5264,  0.1710, -1.4218, -0.8499],
          [-0.1217,  1.1252, -0.0787, -0.5528, -0.8579]],

         [[ 0.4480, -0.1893, -1.4967,  1.4025, -2.0240],
          [-0.0466,  0.7294,  0.6543,  0.3396,  0.5134],
          [-1.4693,  1.6570, -0.4362,  0.8453, -1.4610],
          [ 0.4065,  1.4239, -0.5465, -1.4152,  0.1622]]],


        [[[-1.2751,  0.4853,  0.6003,  0.0665,  1.3609],
          [ 1.2391, -1.2767, -1.1474,  0.2444,  0.3265],
          [-0.2910,  1.4656, -0.2793,  1.8676, -1.4720],
          [-1.2526, -0.7245,  0.9470, -0.1237, -1.4249]],

         [[-1.6543, -1.0721, -0.7083,  0.9666,  0.5009],
          [ 1.4878, -0.4014,  0.2104, -1.7956,  0.4844],
          [ 0.9480,  0.8493,  0.9765, -1.0078,  0.1988],
          [ 1.5906,  0.4707,  0.4127, -0.2247,  0.1714]],

         [[-0.7870,  0.7945, -0.5624, -0.8357,  1.5129],
          [-1.4475, -0.2772,  1.2499, -1.3446, -0.0060],
          [-0.3204,  0.1438,  0.5339, -0.0474, -0.1491],
          [-0.4148,  0.5167,  0.3977,  2.4058, -0.8602]]]])
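The two computations agree up to the eps term; a quick check, reusing bn and x from above:

manual = (x - x.mean(dim=[0, 2, 3], keepdim=True)) / x.std(dim=[0, 2, 3], unbiased=False, keepdim=True)
print(torch.allclose(bn(x), manual, atol=1e-3))  # True: only eps=1e-5 separates them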