Transformer的其它组件#

Note

下图是transformer的结构图,我们已经描述并实现了其中的多头注意力(Multi-head attention)和位置编码(Positional encoding)。
本节我们来讲tranformer的另外的组件:基于位置的前馈网络(Positionwise FFN)、残差连接和层归一化(Add & norm)。

jupyter

基于位置的前馈网络#

即对序列中所有位置的表示进行变换时,使用的是同一个多层感知机(MLP)。

import torch
from torch import nn


#@save
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens):
        super(PositionWiseFFN, self).__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_input)

    def forward(self, X):
        # X shape: (`batch_size`, `num_steps`, `ffn_num_input`)
        # 输入和输出的形状一样
        return self.dense2(self.relu(self.dense1(X)))

残差连接和层归一化#

此组件由残差连接和紧随其后的层归一化组成,两者都是构建有效的深度结构的关键。

\[\mathrm{LayerNorm}(x + \mathrm{SubLayer}(x))\]

层归一化和批量归一化(Batch Normalization)的目标相同,但层归一化的均值和方差在最后几个维度上进行计算。在自然语言处理任务中批量归一化通常不如层归一化效果好。

2D输入#

# 输入num_features
ln = nn.LayerNorm(3)
bn = nn.BatchNorm1d(3)
# shape: (`batch_size`, num_features)
X = torch.tensor([[1, 2, 3], [4, 6, 8]], dtype=torch.float32)
# 层归一化计算每个样本的均值和方差
ln(X), (X - X.mean(axis=1).reshape(-1, 1)) / X.std(axis=1, unbiased=False).reshape(-1, 1)
(tensor([[-1.2247,  0.0000,  1.2247],
         [-1.2247,  0.0000,  1.2247]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[-1.2247,  0.0000,  1.2247],
         [-1.2247,  0.0000,  1.2247]]))
# 批量归一化计算每个特征的均值和方差
bn(X), (X - X.mean(axis=0)) / X.std(axis=0, unbiased=False)
(tensor([[-1.0000, -1.0000, -1.0000],
         [ 1.0000,  1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>),
 tensor([[-1., -1., -1.],
         [ 1.,  1.,  1.]]))

3D输入#

batch, num_features, sentence_length = 2, 3, 4
# 注意这里和 NLP 中常用的 (batch, sentence_length, num_features) 不同,这是为了满足 BatchNorm1d 的输入格式
X = torch.randn(batch, num_features, sentence_length)
# 输入除batch外的维度
layer3d = nn.LayerNorm([3, 4])
# 层归一化计算每个样本的均值和方差
layer3d(X), (X - X.mean(axis=[1, 2]).reshape(-1, 1, 1)) / X.std(axis=[1, 2], unbiased=False).reshape(-1, 1, 1)
(tensor([[[ 1.5608,  1.2644,  0.1408,  0.2393],
          [-0.6965, -1.3973, -1.5091,  0.1032],
          [ 1.0444,  0.8340, -0.5457, -1.0384]],
 
         [[-0.0166, -1.9082,  0.0902, -1.1285],
          [ 1.2292, -1.0437,  1.2204,  1.2079],
          [-0.2656, -0.5193,  1.0890,  0.0453]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 1.5609,  1.2644,  0.1408,  0.2393],
          [-0.6965, -1.3973, -1.5091,  0.1032],
          [ 1.0444,  0.8340, -0.5457, -1.0384]],
 
         [[-0.0166, -1.9082,  0.0902, -1.1286],
          [ 1.2292, -1.0437,  1.2204,  1.2079],
          [-0.2656, -0.5193,  1.0890,  0.0453]]]))
# 只聚合一个维度
layerLast = nn.LayerNorm(4)
layerLast(X), (X - X.mean(axis=[2], keepdims=True)) / X.std(axis=[2], unbiased=False, keepdims=True)
(tensor([[[ 1.2227,  0.7455, -1.0633, -0.9048],
          [ 0.2767, -0.8100, -0.9834,  1.5167],
          [ 1.0956,  0.8582, -0.6989, -1.2549]],
 
         [[ 0.8769, -1.4135,  1.0062, -0.4695],
          [ 0.5875, -1.7320,  0.5786,  0.5658],
          [-0.5768, -0.9914,  1.6369, -0.0687]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 1.2227,  0.7455, -1.0633, -0.9048],
          [ 0.2767, -0.8100, -0.9834,  1.5167],
          [ 1.0956,  0.8582, -0.6989, -1.2549]],
 
         [[ 0.8769, -1.4136,  1.0062, -0.4695],
          [ 0.5875, -1.7320,  0.5786,  0.5658],
          [-0.5769, -0.9914,  1.6370, -0.0687]]]))
# batchNorm1d适用于输入为2d和3d的情况,它计算每个feature的均值和方差
bn(X), (X - X.mean(axis=[0, 2]).reshape(1, -1, 1)) / X.std(axis=[0, 2], unbiased=False).reshape(1, -1, 1)
(tensor([[[ 1.4244,  1.1652,  0.1826,  0.2687],
          [-0.4652, -1.1100, -1.2129,  0.2706],
          [ 1.3618,  1.0893, -0.6976, -1.3357]],
 
         [[-0.1105, -1.8076, -0.0147, -1.1081],
          [ 1.1729, -0.9728,  1.1646,  1.1528],
          [-0.5735, -0.9105,  1.2265, -0.1603]]],
        grad_fn=<NativeBatchNormBackward0>),
 tensor([[[ 1.4244,  1.1652,  0.1826,  0.2687],
          [-0.4652, -1.1100, -1.2129,  0.2706],
          [ 1.3618,  1.0893, -0.6976, -1.3357]],
 
         [[-0.1105, -1.8076, -0.0147, -1.1081],
          [ 1.1729, -0.9728,  1.1646,  1.1528],
          [-0.5735, -0.9105,  1.2265, -0.1603]]]))

实现#

#@save
class AddNorm(nn.Module):
    """残差连接和层归一化"""
    def __init__(self, normalized_shape, dropout):
        super(AddNorm, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # normalized_shape指定均值和方差计算的维度,需是后几个维度
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # 先残差连接,再层归一化
        return self.ln(self.dropout(Y) + X)
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
# 形状不变
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
torch.Size([2, 3, 4])