Transformer Encoder#

Note

The Transformer is an instance of the encoder-decoder architecture that is based entirely on attention, with multi-head attention at its core; its overall architecture is shown in the figure below.
In the previous sections we introduced the individual components of the Transformer; in this section we use those components to build the Transformer Encoder.

(Figure: the overall architecture of the Transformer)

Encoder Block#

The EncoderBlock class below implements the dashed box on the left of the diagram. It contains two sublayers: multi-head attention and a position-wise feed-forward network, and each sublayer is wrapped in a residual connection followed by layer normalization.
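
For reference, here is a minimal sketch of what the two helper layers used below are assumed to look like; the actual d2l.AddNorm and d2l.PositionWiseFFN were introduced in earlier sections and may differ in detail.

import torch
from torch import nn


class AddNorm(nn.Module):
    """Sketch: residual connection followed by layer normalization."""
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        # Add the (dropped-out) sublayer output Y to its input X, then normalize
        return self.ln(self.dropout(Y) + X)


class PositionWiseFFN(nn.Module):
    """Sketch: two-layer MLP applied independently at every position."""
    def __init__(self, num_hiddens, ffn_num_hiddens):
        super().__init__()
        self.dense1 = nn.Linear(num_hiddens, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, num_hiddens)

    def forward(self, X):
        # Same transformation at every time step; the output keeps the input width
        return self.dense2(self.relu(self.dense1(X)))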

import torch
from torch import nn
import d2l
import math


class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads, 
                 dropout, use_bias=False):
        super(EncoderBlock, self).__init__()
        # key_size=query_size=value_size in Transformer
        # Multi-head attention
        # num_hiddens must be divisible by num_heads; each head has width num_hiddens // num_heads, and W_o maps num_hiddens -> num_hiddens
        self.attention = d2l.MultiHeadAttention(key_size, query_size,
                                                value_size, num_hiddens,
                                                num_heads, dropout, use_bias)
        # First add & norm
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        # Position-wise FFN
        self.ffn = d2l.PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        # Second add & norm
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # `X` shape: (`batch_size`, `num_steps`, `num_hiddens`)
        # `valid_lens` shape: None or (`batch_size`,) or (`batch_size`, `num_steps`)
        # First sublayer
        # Positions masked out in the attention are still computed afterwards, but apart from LayerNorm every computation is position-wise
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        # Second sublayer; the shape is unchanged
        return self.addnorm2(Y, self.ffn(Y))
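
As a quick sanity check (the shapes here are hypothetical, chosen only for illustration), an encoder block does not change the shape of its input:

X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 48, 8, 0.5)
encoder_blk.eval()
print(encoder_blk(X, valid_lens).shape)  # expected: torch.Size([2, 100, 24])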

Encoder#

With the block defined, the TransformerEncoder class below stacks num_layers EncoderBlock instances on top of a token embedding and positional encoding.

class TransformerEncoder(d2l.Encoder):
    """Transformer的编码器"""
    def __init__(self, vocab_size, num_hiddens, 
                 norm_shape, ffn_num_hiddens, num_heads, 
                 num_layers, dropout, use_bias=False):
        super(TransformerEncoder, self).__init__()
        self.num_hiddens = num_hiddens
        # The embedding maps token indices from a vocabulary of size `vocab_size` to `num_hiddens`-dimensional vectors
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        # Positional encoding
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # Stack of EncoderBlocks
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
                             norm_shape, ffn_num_hiddens, num_heads, 
                             dropout, use_bias))

    def forward(self, X, valid_lens):
        # `X` shape: (`batch_size`, `num_steps`), containing token indices
        # Since the positional encoding values lie in [-1, 1], the embedding is scaled by sqrt(num_hiddens) so that both are on a comparable magnitude before being summed
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # The same `valid_lens` is applied in every block
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X
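
A similarly hypothetical usage sketch: a two-layer encoder maps a batch of token indices to num_hiddens-dimensional representations, one vector per position.

encoder = TransformerEncoder(200, 24, [100, 24], 48, 8, 2, 0.5)
encoder.eval()
tokens = torch.ones((2, 100), dtype=torch.long)
valid_lens = torch.tensor([3, 2])
print(encoder(tokens, valid_lens).shape)  # expected: torch.Size([2, 100, 24])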