初始化

Note

初始化参数时需打破神经元的对称性
Xavier初始化和Kaiming初始化的intuition:在正向传播或反向传播时,保证输入和输出的量级基本相等

对称性

想象一下,如果我们将层 \(l\) 的所有参数初始化为 \(\mathbf{W}^{[l]}=c_{l}, \mathbf{b}^{[l]} = d_{l}\) , 其中 \(c_{l},d_{l}\) 为常量,会发生什么情况:

正向和反向传播时,层 \(l\) 的所有神经元都是对称的,也就是说层 \(l\) 实际上相当于只有一个神经元。

Xavier 初始化

考虑一个没有 bias term 和激活函数的全连接层:

\[o_{i} = \sum_{j=1}^{\text{fan}_{in}}w_{ij}x_{j}\]

假设 \(w_{ij}\) 的均值为0,方差为 \(\sigma^{2}\); \(x_{j}\) 的均值为0,方差为 \(\gamma^{2}\); \(w_{ij}, x_{j}\) 相互独立。那么 \(o_{i}\) 的均值显然为0,它的方差为:

\[\begin{split} \begin{equation} \begin{split} \text{Var}[o_{i}] =& E[o_{i}^{2}] - (E[o_{i}])^{2}\\ =&\sum_{j=1}^{\text{fan}_{in}}E[w_{ij}^{2}x_{j}^{2}] \\ =&\sum_{j=1}^{\text{fan}_{in}}E[w_{ij}^{2}]E[x_{j}^{2}] \\ =&\text{fan}_{in}\sigma^{2}\gamma^{2} \end{split} \end{equation} \end{split}\]

为了保证输入和输出的方差不变,需要 \(\text{fan}_{in}\sigma^{2}=1\)

另一方面,在反向传播中, 我们有:

\[\frac{\partial L}{\partial x_{j}} = \sum_{i=1}^{\text{fan}_{out}}w_{ij}\frac{\partial L}{\partial o_{i}}\]

类似地,为了维持梯度的方差,需要 \(\text{fan}_{out}\sigma^{2} = 1\)

\(\text{fan}_{in} \ne \text{fan}_{out}\) 上述两式不能同时成立, pytorch的实现里取了折中:

\[\frac{1}{2}(\text{fan}_{in} + \text{fan}_{out})\sigma^{2} = 1 \ \text{ or }\ \sigma = \sqrt{\frac{2}{\text{fan}_{in} + \text{fan}_{out}}}\]
import torch
from torch import nn

w = torch.empty(3, 5)
# xavier_normal初始化
nn.init.xavier_normal_(w)
tensor([[ 0.0330, -0.1627,  0.5484, -0.6934, -0.0037],
        [ 0.2257,  0.4591, -0.5610, -0.0438, -0.0834],
        [ 0.0437,  0.2729, -0.2554,  0.0250,  0.6984]])

Kaiming 初始化

当激活函数为 \(\mbox{ReLU}\) 时:

\[o_{i} = \mbox{ReLU}(\sum_{j=1}^{\text{fan}_{in}}w_{ij}x_{j})\]

在与Xavier初始化一样的假设下,\(o_{i}\)的方差为Xavier初始化中\(o_{i}\)方差的一半(因为有一半为0):

\[\text{Var}(o_{i}) = \frac{1}{2}\text{fan}_{in}\sigma^{2}\gamma^{2}\]

保持正向传播时方差不变需 \(\text{fan}_{in}\sigma^{2}=2\)。保持梯度方差不变需 \(\text{fan}_{out}\sigma^{2}=2\)。pytorch中实现的Kaiming初始化:

\[\sigma = \sqrt{\frac{2}{\mbox{fan}_{mode}}}\]

其中 \(\text{fan}_{mode}\)\(\text{fan}_{in}\) or \(\text{fan}_{out}\).

w = torch.empty(3, 5)
# nn.Linear默认使用kaiming_uniform_,其中mode=`fan_in`,即优先稳定正向传播
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
tensor([[-0.1288,  0.2462,  0.2680, -0.2140, -0.5077],
        [ 1.0324, -0.8032,  0.0058,  0.2864, -0.8681],
        [-0.3520, -1.0235, -0.3865,  0.1485, -0.6039]])