初始化#

Note

初始化参数时需打破神经元的对称性
Xavier初始化和Kaiming初始化的intuition:在正向传播或反向传播时,保证输入和输出的量级基本相等

对称性#

想象一下,如果我们将层 \(l\) 的所有参数初始化为 \(\mathbf{W}^{[l]}=c_{l}, \mathbf{b}^{[l]} = d_{l}\) , 其中 \(c_{l},d_{l}\) 为常量,会发生什么情况:

正向和反向传播时,层 \(l\) 的所有神经元都是对称的,也就是说层 \(l\) 实际上相当于只有一个神经元。

Xavier 初始化#

考虑一个没有 bias term 和激活函数的全连接层:

\[o_{i} = \sum_{j=1}^{\text{fan}_{in}}w_{ij}x_{j}\]

假设 \(w_{ij}\) 的均值为0,方差为 \(\sigma^{2}\); \(x_{j}\) 的均值为0,方差为 \(\gamma^{2}\); \(w_{ij}, x_{j}\) 相互独立。那么 \(o_{i}\) 的均值显然为0,它的方差为:

\[\begin{split} \begin{equation} \begin{split} \text{Var}[o_{i}] =& E[o_{i}^{2}] - (E[o_{i}])^{2}\\ =&\sum_{j=1}^{\text{fan}_{in}}E[w_{ij}^{2}x_{j}^{2}] \\ =&\sum_{j=1}^{\text{fan}_{in}}E[w_{ij}^{2}]E[x_{j}^{2}] \\ =&\text{fan}_{in}\sigma^{2}\gamma^{2} \end{split} \end{equation} \end{split}\]

为了保证输入和输出的方差不变,需要 \(\text{fan}_{in}\sigma^{2}=1\)

另一方面,在反向传播中, 我们有:

\[\frac{\partial L}{\partial x_{j}} = \sum_{i=1}^{\text{fan}_{out}}w_{ij}\frac{\partial L}{\partial o_{i}}\]

类似地,为了维持梯度的方差,需要 \(\text{fan}_{out}\sigma^{2} = 1\)

\(\text{fan}_{in} \ne \text{fan}_{out}\) 上述两式不能同时成立, pytorch的实现里取了折中:

\[\frac{1}{2}(\text{fan}_{in} + \text{fan}_{out})\sigma^{2} = 1 \ \text{ or }\ \sigma = \sqrt{\frac{2}{\text{fan}_{in} + \text{fan}_{out}}}\]
import torch
from torch import nn

w = torch.empty(3, 5)
# xavier_normal初始化
nn.init.xavier_normal_(w)
tensor([[-0.0332, -0.0838, -0.4253, -0.4172,  0.3161],
        [ 0.6112, -0.0679,  0.3572,  0.4847, -0.4388],
        [ 0.4275,  0.0573, -0.3956, -0.2125,  0.7110]])

Kaiming 初始化#

当激活函数为 \(\mbox{ReLU}\) 时:

\[o_{i} = \mbox{ReLU}(\sum_{j=1}^{\text{fan}_{in}}w_{ij}x_{j})\]

在与Xavier初始化一样的假设下,\(o_{i}\)的方差为Xavier初始化中\(o_{i}\)方差的一半(因为有一半为0):

\[\text{Var}(o_{i}) = \frac{1}{2}\text{fan}_{in}\sigma^{2}\gamma^{2}\]

保持正向传播时方差不变需 \(\text{fan}_{in}\sigma^{2}=2\)。保持梯度方差不变需 \(\text{fan}_{out}\sigma^{2}=2\)。pytorch中实现的Kaiming初始化:

\[\sigma = \sqrt{\frac{2}{\mbox{fan}_{mode}}}\]

其中 \(\text{fan}_{mode}\)\(\text{fan}_{in}\) or \(\text{fan}_{out}\).

w = torch.empty(3, 5)
# nn.Linear默认使用kaiming_uniform_,其中mode=`fan_in`,即优先稳定正向传播
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
tensor([[ 0.7250, -1.0235,  0.1649, -0.2647,  0.9801],
        [ 0.5283,  0.9472,  0.4259,  0.5999, -0.2434],
        [-0.2071, -0.7180,  0.3021, -0.6649,  0.1248]])