Transformer Encoder#
Note
The Transformer is an instance of the encoder-decoder architecture. It is based entirely on attention mechanisms, with multi-head attention at its core; its overall architecture is shown in the figure below.
In the previous sections we introduced the individual components of the Transformer. In this section we use those components to build the Transformer Encoder.
Encoder Block#
The EncoderBlock class below implements the dashed box on the left of the architecture diagram. It contains two sublayers, multi-head attention and a positionwise feed-forward network, and each sublayer is wrapped in a residual connection followed by layer normalization.
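The block relies on two d2l helper layers, AddNorm and PositionWiseFFN. As a reference point, here is a minimal sketch of what such layers typically look like, assuming the d2l-style conventions that AddNorm(norm_shape, dropout) computes a dropout-regularized residual connection followed by LayerNorm, and that the positionwise FFN applies the same two dense layers to every position independently. The class names with the Sketch suffix are illustrative, not the actual d2l implementations.
from torch import nn

class AddNormSketch(nn.Module):
    # Residual connection followed by layer normalization (d2l-style AddNorm).
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        # Y is the sublayer output; add it (after dropout) to the input X,
        # then apply layer normalization over `norm_shape`
        return self.ln(self.dropout(Y) + X)

class PositionWiseFFNSketch(nn.Module):
    # Two dense layers applied along the last dimension, i.e. to every
    # position independently (hence "positionwise").
    def __init__(self, ffn_num_input, ffn_num_hiddens):
        super().__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_input)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
Neither layer changes the shape of its input.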
import math
import torch
from torch import nn
import d2l
class EncoderBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False):
        super(EncoderBlock, self).__init__()
        # In the Transformer, key_size = query_size = value_size
        # Multi-head attention:
        # num_hiddens must be divisible by num_heads; each head has width
        # num_hiddens // num_heads, and W_o maps num_hiddens -> num_hiddens
        self.attention = d2l.MultiHeadAttention(key_size, query_size,
                                                value_size, num_hiddens,
                                                num_heads, dropout, use_bias)
        # First add & norm
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        # Positionwise feed-forward network
        self.ffn = d2l.PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        # Second add & norm
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # `X` shape: (`batch_size`, `num_steps`, `num_hiddens`)
        # `valid_lens` shape: None, (`batch_size`,) or (`batch_size`, `num_steps`)
        # First sublayer: positions beyond `valid_lens` still produce outputs
        # after attention, but apart from layer normalization every
        # computation is carried out independently at each position
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        # Second sublayer: the output shape equals the input shape
        return self.addnorm2(Y, self.ffn(Y))
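Because of the residual connections, an encoder block never changes the shape of its input. A quick shape check, assuming the d2l helpers above are available (the sizes 2, 100, 24, 48, 8 are arbitrary illustrative values):
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 48, 8, 0.5)
encoder_blk.eval()
print(encoder_blk(X, valid_lens).shape)  # expected: torch.Size([2, 100, 24])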
Encoder#
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder."""
    def __init__(self, vocab_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 num_layers, dropout, use_bias=False):
        super(TransformerEncoder, self).__init__()
        self.num_hiddens = num_hiddens
        # The embedding maps token indices (vocabulary of size `vocab_size`)
        # to vectors of dimension `num_hiddens`
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        # Positional encoding
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # Stack of EncoderBlocks
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
                             norm_shape, ffn_num_hiddens, num_heads,
                             dropout, use_bias))

    def forward(self, X, valid_lens):
        # `X` shape: (`batch_size`, `num_steps`), token indices
        # Since the positional encoding values lie between -1 and 1, the
        # embeddings are multiplied by the square root of `num_hiddens` so
        # that the two are on a comparable scale before being summed
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # The same `valid_lens` is applied in every block
        for blk in self.blks:
            X = blk(X, valid_lens)
        # Output shape: (`batch_size`, `num_steps`, `num_hiddens`)
        return X
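A quick shape check for the full encoder, again with arbitrary illustrative sizes: the input is a batch of token indices of shape (batch_size, num_steps), and the output has shape (batch_size, num_steps, num_hiddens).
encoder = TransformerEncoder(vocab_size=200, num_hiddens=24,
                             norm_shape=[100, 24], ffn_num_hiddens=48,
                             num_heads=8, num_layers=2, dropout=0.5)
encoder.eval()
tokens = torch.ones((2, 100), dtype=torch.long)
valid_lens = torch.tensor([3, 2])
print(encoder(tokens, valid_lens).shape)  # expected: torch.Size([2, 100, 24])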