Transformer的其它组件#
Note
下图是transformer的结构图,我们已经描述并实现了其中的多头注意力(Multi-head attention)和位置编码(Positional encoding)。
本节我们来讲tranformer的另外的组件:基于位置的前馈网络(Positionwise FFN)、残差连接和层归一化(Add & norm)。
基于位置的前馈网络#
即对序列中所有位置的表示进行变换时,使用的是同一个多层感知机(MLP)。
import torch
from torch import nn
#@save
class PositionWiseFFN(nn.Module):
"""基于位置的前馈网络"""
def __init__(self, ffn_num_input, ffn_num_hiddens):
super(PositionWiseFFN, self).__init__()
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_input)
def forward(self, X):
# X shape: (`batch_size`, `num_steps`, `ffn_num_input`)
# 输入和输出的形状一样
return self.dense2(self.relu(self.dense1(X)))
残差连接和层归一化#
此组件由残差连接和紧随其后的层归一化组成,两者都是构建有效的深度结构的关键。
\[\mathrm{LayerNorm}(x + \mathrm{SubLayer}(x))\]
层归一化和批量归一化(Batch Normalization)的目标相同,但层归一化的均值和方差在最后几个维度上进行计算。在自然语言处理任务中批量归一化通常不如层归一化效果好。
2D输入#
# 输入num_features
ln = nn.LayerNorm(3)
bn = nn.BatchNorm1d(3)
# shape: (`batch_size`, num_features)
X = torch.tensor([[1, 2, 3], [4, 6, 8]], dtype=torch.float32)
# 层归一化计算每个样本的均值和方差
ln(X), (X - X.mean(axis=1).reshape(-1, 1)) / X.std(axis=1, unbiased=False).reshape(-1, 1)
(tensor([[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]], grad_fn=<NativeLayerNormBackward0>),
tensor([[-1.2247, 0.0000, 1.2247],
[-1.2247, 0.0000, 1.2247]]))
# 批量归一化计算每个特征的均值和方差
bn(X), (X - X.mean(axis=0)) / X.std(axis=0, unbiased=False)
(tensor([[-1.0000, -1.0000, -1.0000],
[ 1.0000, 1.0000, 1.0000]], grad_fn=<NativeBatchNormBackward0>),
tensor([[-1., -1., -1.],
[ 1., 1., 1.]]))
3D输入#
batch, num_features, sentence_length = 2, 3, 4
# 注意这里和 NLP 中常用的 (batch, sentence_length, num_features) 不同,这是为了满足 BatchNorm1d 的输入格式
X = torch.randn(batch, num_features, sentence_length)
# 输入除batch外的维度
layer3d = nn.LayerNorm([3, 4])
# 层归一化计算每个样本的均值和方差
layer3d(X), (X - X.mean(axis=[1, 2]).reshape(-1, 1, 1)) / X.std(axis=[1, 2], unbiased=False).reshape(-1, 1, 1)
(tensor([[[ 1.5608, 1.2644, 0.1408, 0.2393],
[-0.6965, -1.3973, -1.5091, 0.1032],
[ 1.0444, 0.8340, -0.5457, -1.0384]],
[[-0.0166, -1.9082, 0.0902, -1.1285],
[ 1.2292, -1.0437, 1.2204, 1.2079],
[-0.2656, -0.5193, 1.0890, 0.0453]]],
grad_fn=<NativeLayerNormBackward0>),
tensor([[[ 1.5609, 1.2644, 0.1408, 0.2393],
[-0.6965, -1.3973, -1.5091, 0.1032],
[ 1.0444, 0.8340, -0.5457, -1.0384]],
[[-0.0166, -1.9082, 0.0902, -1.1286],
[ 1.2292, -1.0437, 1.2204, 1.2079],
[-0.2656, -0.5193, 1.0890, 0.0453]]]))
# 只聚合一个维度
layerLast = nn.LayerNorm(4)
layerLast(X), (X - X.mean(axis=[2], keepdims=True)) / X.std(axis=[2], unbiased=False, keepdims=True)
(tensor([[[ 1.2227, 0.7455, -1.0633, -0.9048],
[ 0.2767, -0.8100, -0.9834, 1.5167],
[ 1.0956, 0.8582, -0.6989, -1.2549]],
[[ 0.8769, -1.4135, 1.0062, -0.4695],
[ 0.5875, -1.7320, 0.5786, 0.5658],
[-0.5768, -0.9914, 1.6369, -0.0687]]],
grad_fn=<NativeLayerNormBackward0>),
tensor([[[ 1.2227, 0.7455, -1.0633, -0.9048],
[ 0.2767, -0.8100, -0.9834, 1.5167],
[ 1.0956, 0.8582, -0.6989, -1.2549]],
[[ 0.8769, -1.4136, 1.0062, -0.4695],
[ 0.5875, -1.7320, 0.5786, 0.5658],
[-0.5769, -0.9914, 1.6370, -0.0687]]]))
# batchNorm1d适用于输入为2d和3d的情况,它计算每个feature的均值和方差
bn(X), (X - X.mean(axis=[0, 2]).reshape(1, -1, 1)) / X.std(axis=[0, 2], unbiased=False).reshape(1, -1, 1)
(tensor([[[ 1.4244, 1.1652, 0.1826, 0.2687],
[-0.4652, -1.1100, -1.2129, 0.2706],
[ 1.3618, 1.0893, -0.6976, -1.3357]],
[[-0.1105, -1.8076, -0.0147, -1.1081],
[ 1.1729, -0.9728, 1.1646, 1.1528],
[-0.5735, -0.9105, 1.2265, -0.1603]]],
grad_fn=<NativeBatchNormBackward0>),
tensor([[[ 1.4244, 1.1652, 0.1826, 0.2687],
[-0.4652, -1.1100, -1.2129, 0.2706],
[ 1.3618, 1.0893, -0.6976, -1.3357]],
[[-0.1105, -1.8076, -0.0147, -1.1081],
[ 1.1729, -0.9728, 1.1646, 1.1528],
[-0.5735, -0.9105, 1.2265, -0.1603]]]))
实现#
#@save
class AddNorm(nn.Module):
"""残差连接和层归一化"""
def __init__(self, normalized_shape, dropout):
super(AddNorm, self).__init__()
self.dropout = nn.Dropout(dropout)
# normalized_shape指定均值和方差计算的维度,需是后几个维度
self.ln = nn.LayerNorm(normalized_shape)
def forward(self, X, Y):
# 先残差连接,再层归一化
return self.ln(self.dropout(Y) + X)
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
# 形状不变
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
torch.Size([2, 3, 4])