Batch Normalization#
Note
Like regularization and Dropout, Batch Normalization (BN) is also a technique for combating overfitting in deep neural networks.
Batch Normalization also accelerates convergence, allowing us to train deeper networks.
Intuition#
Recall from the first section that when we used linear regression to predict house prices, the first step was to standardize the input features. Neural networks likewise benefit from standardized inputs, and not only at the input layer: the intermediate layers should be standardized as well.
Moreover, deeper networks are very complex and prone to overfitting, which creates a need for regularization.
Batch Normalization simply standardizes each mini-batch and then applies a learned scale and shift (standardization with the added flexibility of a scale and an offset):
\[\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}}}{\hat{\boldsymbol{\sigma}}_{\mathcal{B}}} + \boldsymbol{\beta}\]
where \(\hat{\boldsymbol{\mu}}_{\mathcal{B}}\) and \(\hat{\boldsymbol{\sigma}}_{\mathcal{B}}\) are the mean and standard deviation of the mini-batch \(\mathcal{B}\); the scale parameter \(\boldsymbol{\gamma}\) and shift parameter \(\boldsymbol{\beta}\) have the same shape as \(\mathbf{x}\) and are learned by the model.
The mean and standard deviation are computed as:
\[\hat{\boldsymbol{\mu}}_{\mathcal{B}} = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}, \qquad \hat{\boldsymbol{\sigma}}_{\mathcal{B}}^2 = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \left(\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}}\right)^2 + \epsilon\]
where a small constant \(\epsilon > 0\) is added to the variance estimate so we never divide by zero.
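To make the formulas concrete, here is a minimal from-scratch sketch of the training-time computation for a fully connected layer (the function name batch_norm and the eps value are our own illustrative choices, not from the original text):
import torch

def batch_norm(X, gamma, beta, eps=1e-5):
    # X has shape (batch, features); statistics are per feature, over the batch
    mu = X.mean(dim=0, keepdim=True)
    var = ((X - mu) ** 2).mean(dim=0, keepdim=True)
    X_hat = (X - mu) / torch.sqrt(var + eps)  # standardize
    return gamma * X_hat + beta               # scale and shift

X = torch.randn(8, 4)
Y = batch_norm(X, torch.ones(4), torch.zeros(4))
Y.mean(dim=0), Y.std(dim=0, unbiased=False)  # per-feature mean ~0, std ~1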
Intuition 2#
In a deep neural network, the layers near the top (the output) have large gradients and update quickly, while the layers near the bottom suffer from vanishing gradients and update slowly. During training, the top layers therefore tend to converge quickly while the bottom layers converge very slowly.
This is a problem: the top layers depend on the bottom layers, so every time the bottom-layer parameters change, the top layers must effectively be retrained even if they had already converged; in turn, changes in the top-layer parameters propagate back to the bottom layers during backpropagation. This back-and-forth makes training inefficient. Batch Normalization addresses it as follows:
It fixes the mean and variance within each mini-batch, so the distribution each layer sees stays stable.
For fully connected layers, it acts on the feature dimension.
For convolutional layers, it acts on the channel dimension (which plays the role of the feature dimension, as in a 1×1 convolution); see the sketch below.
The mini-batch statistics inject noise into each update, which helps control model complexity.
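The difference in which axes the statistics are taken over can be shown directly (a sketch; the shapes are arbitrary examples):
import torch

X_fc = torch.randn(8, 100)        # fully connected input: (batch, features)
X_conv = torch.randn(8, 3, 4, 5)  # conv input: (batch, channels, height, width)

# fully connected: one mean per feature, computed over the batch dimension
X_fc.mean(dim=0).shape            # torch.Size([100])

# convolutional: one mean per channel, computed over batch and spatial dimensions
X_conv.mean(dim=[0, 2, 3]).shape  # torch.Size([3])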
Batch Normalization in PyTorch#
import torch
from torch import nn
net = nn.Sequential(
    # BN usually goes after a fully connected layer and before the activation
    # function; the number of input features must be specified.
    # CNNs have a corresponding BN layer, BatchNorm2d, covered below
    nn.Linear(784, 100), nn.BatchNorm1d(100), nn.ReLU(),
    nn.Linear(100, 64), nn.BatchNorm1d(64), nn.ReLU(),
    nn.Linear(64, 10))
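One point to keep in mind when using these layers: in training mode the BN layers normalize with the current mini-batch's statistics (while updating running averages), and in evaluation mode they switch to the accumulated running statistics, so toggle the mode explicitly:
net.train()  # BN uses per-batch statistics and updates the running mean/var
net.eval()   # BN uses the stored running statistics; outputs become deterministic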
Warning
Using Batch Normalization together with Dropout tends to make things worse (1 + 1 < 1), so don't combine them; pick one or the other.
Batch Normalization also has a regularizing effect, so it is generally not combined with explicit regularization either.
bn = nn.BatchNorm2d(3)       # BN for convolutional inputs, over 3 channels
x = torch.randn(2, 3, 4, 5)  # (batch, channels, height, width)
x
tensor([[[[-0.7056, -0.4590, 0.6550, -0.8764, 0.1595],
[-0.3366, 1.2398, 0.7763, -0.0936, -0.0803],
[-0.6050, -1.6816, -0.3381, 1.2039, 0.1787],
[ 1.2616, 0.1024, -0.3560, 1.8771, -0.3419]],
[[-0.0999, -0.5607, 1.3607, -1.6053, 0.8459],
[-0.4745, -0.2351, 0.2903, -0.2562, -1.3345],
[ 1.0218, -2.7007, -0.1659, -1.6627, -1.1253],
[-0.4410, 0.7307, -0.4006, -0.8461, -1.1328]],
[[ 0.3236, -0.2399, -1.3959, 1.1676, -1.8621],
[-0.1137, 0.5725, 0.5061, 0.2277, 0.3814],
[-1.3717, 1.3927, -0.4582, 0.6749, -1.3643],
[ 0.2869, 1.1865, -0.5557, -1.3239, 0.0709]]],
[[[-1.0764, 0.4783, 0.5798, 0.1085, 1.2515],
[ 1.1440, -1.0778, -0.9636, 0.2656, 0.3380],
[-0.2073, 1.3440, -0.1970, 1.6990, -1.2503],
[-1.0565, -0.5901, 0.8860, -0.0596, -1.2087]],
[[-1.8812, -1.3341, -0.9922, 0.5816, 0.1440],
[ 1.0714, -0.7039, -0.1289, -2.0139, 0.1285],
[ 0.5642, 0.4715, 0.5909, -1.2737, -0.1398],
[ 1.1680, 0.1156, 0.0612, -0.5378, -0.1655]],
[[-0.7684, 0.6300, -0.5697, -0.8114, 1.2653],
[-1.3524, -0.3176, 1.0327, -1.2614, -0.0778],
[-0.3558, 0.0546, 0.3996, -0.1144, -0.2044],
[-0.4393, 0.3843, 0.2791, 2.0548, -0.8331]]]])
bn(x)
tensor([[[[-0.8553, -0.5760, 0.6854, -1.0486, 0.1244],
[-0.4374, 1.3476, 0.8227, -0.1623, -0.1472],
[-0.7413, -1.9604, -0.4391, 1.3069, 0.1461],
[ 1.3723, 0.0597, -0.4594, 2.0693, -0.4435]],
[[ 0.2413, -0.2491, 1.7957, -1.3607, 1.2478],
[-0.1574, 0.0974, 0.6566, 0.0750, -1.0726],
[ 1.4350, -2.5264, 0.1710, -1.4218, -0.8499],
[-0.1217, 1.1252, -0.0787, -0.5528, -0.8579]],
[[ 0.4480, -0.1893, -1.4967, 1.4025, -2.0239],
[-0.0466, 0.7294, 0.6543, 0.3396, 0.5134],
[-1.4693, 1.6570, -0.4362, 0.8453, -1.4609],
[ 0.4065, 1.4239, -0.5465, -1.4152, 0.1622]]],
[[[-1.2751, 0.4853, 0.6003, 0.0665, 1.3609],
[ 1.2391, -1.2767, -1.1474, 0.2444, 0.3265],
[-0.2910, 1.4656, -0.2793, 1.8676, -1.4720],
[-1.2526, -0.7245, 0.9470, -0.1237, -1.4249]],
[[-1.6543, -1.0721, -0.7083, 0.9666, 0.5009],
[ 1.4878, -0.4014, 0.2104, -1.7956, 0.4844],
[ 0.9480, 0.8493, 0.9765, -1.0078, 0.1988],
[ 1.5906, 0.4707, 0.4127, -0.2247, 0.1714]],
[[-0.7870, 0.7944, -0.5624, -0.8357, 1.5129],
[-1.4475, -0.2772, 1.2499, -1.3446, -0.0060],
[-0.3204, 0.1438, 0.5339, -0.0474, -0.1491],
[-0.4148, 0.5167, 0.3977, 2.4058, -0.8602]]]],
grad_fn=<NativeBatchNormBackward0>)
# one mean and one variance per channel
(x - x.mean(axis=[0, 2, 3], keepdims=True)) / x.std(axis=[0, 2, 3], unbiased=False, keepdims=True)
tensor([[[[-0.8553, -0.5760, 0.6854, -1.0486, 0.1244],
[-0.4374, 1.3476, 0.8227, -0.1623, -0.1472],
[-0.7413, -1.9604, -0.4391, 1.3069, 0.1461],
[ 1.3723, 0.0597, -0.4594, 2.0693, -0.4435]],
[[ 0.2413, -0.2491, 1.7957, -1.3607, 1.2478],
[-0.1574, 0.0974, 0.6566, 0.0750, -1.0726],
[ 1.4351, -2.5264, 0.1710, -1.4218, -0.8499],
[-0.1217, 1.1252, -0.0787, -0.5528, -0.8579]],
[[ 0.4480, -0.1893, -1.4967, 1.4025, -2.0240],
[-0.0466, 0.7294, 0.6543, 0.3396, 0.5134],
[-1.4693, 1.6570, -0.4362, 0.8453, -1.4610],
[ 0.4065, 1.4239, -0.5465, -1.4152, 0.1622]]],
[[[-1.2751, 0.4853, 0.6003, 0.0665, 1.3609],
[ 1.2391, -1.2767, -1.1474, 0.2444, 0.3265],
[-0.2910, 1.4656, -0.2793, 1.8676, -1.4720],
[-1.2526, -0.7245, 0.9470, -0.1237, -1.4249]],
[[-1.6543, -1.0721, -0.7083, 0.9666, 0.5009],
[ 1.4878, -0.4014, 0.2104, -1.7956, 0.4844],
[ 0.9480, 0.8493, 0.9765, -1.0078, 0.1988],
[ 1.5906, 0.4707, 0.4127, -0.2247, 0.1714]],
[[-0.7870, 0.7945, -0.5624, -0.8357, 1.5129],
[-1.4475, -0.2772, 1.2499, -1.3446, -0.0060],
[-0.3204, 0.1438, 0.5339, -0.0474, -0.1491],
[-0.4148, 0.5167, 0.3977, 2.4058, -0.8602]]]])
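The manual result matches bn(x) because a freshly constructed BatchNorm2d initializes \(\boldsymbol{\gamma} = 1\) and \(\boldsymbol{\beta} = 0\); the occasional difference in the last digit comes from the small eps (default 1e-5) that the layer adds inside the square root. A quick check (a sketch reusing the bn and x defined above):
manual = (x - x.mean(dim=[0, 2, 3], keepdim=True)) / torch.sqrt(
    x.var(dim=[0, 2, 3], unbiased=False, keepdim=True) + bn.eps)
torch.allclose(bn(x), manual)  # True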