Transformer Decoder#

Note

In this section, we assemble the Transformer decoder from the components introduced earlier.


Decoder Block#

As shown in the architecture diagram, the Transformer decoder is likewise a stack of layers with identical structure.

Each DecoderBlock contains three sublayers: decoder self-attention, encoder-decoder attention, and a position-wise feed-forward network.

In the masked multi-head self-attention layer (the first sublayer), the queries, keys, and values all come from the output of the previous decoder layer. During training, the tokens at every time step of the output sequence are known; during prediction, however, the output sequence is generated one token at a time, so each position may only attend to the tokens produced so far.
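The implementation below enforces this during training with a `dec_valid_lens` mask, in which position t is allowed to attend only to the first t tokens. A minimal illustration of that mask (`batch_size` and `num_steps` here are made-up toy values):

import torch

batch_size, num_steps = 2, 4
# Every row is [1, 2, ..., num_steps]: position t may attend
# to the first t tokens of the decoder input
dec_valid_lens = torch.arange(1, num_steps + 1).repeat(batch_size, 1)
print(dec_valid_lens)
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])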

import torch
from torch import nn
import d2l
import math


class DecoderBlock(nn.Module):
    """解码器中的第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads, 
                 dropout, i):
        super(DecoderBlock, self).__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)
        self.ffn = d2l.PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm3 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        # During training, `X` shape: (`batch_size`, `num_steps`, `num_hiddens`)
        # During prediction, `X` shape: (`batch_size`, 1, `num_hiddens`)
        # enc_outputs comes from the encoder (the output of its last encoder
        # block), shape: (`batch_size`, `num_steps`, `num_hiddens`)
        # enc_valid_lens also comes from the encoder
        enc_outputs, enc_valid_lens = state[0], state[1]
        
        # `state[2][self.i]` is used during prediction: it is initialized to None
        # and stores this block's output sequence up to the current time step
        # Training, or prediction of the first token
        if state[2][self.i] is None:
            key_values = X
        # Subsequent prediction steps
        else:
            # Unlike an RNN-based seq2seq model, the Transformer attends over the
            # entire output sequence generated so far, not just the previous step
            # key_values shape: (`batch_size`, `cur_steps`, `num_hiddens`)
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal masking: position t may attend only to the first t tokens
            # shape of dec_valid_lens: (`batch_size`, `num_steps`),
            # where every row is [1, 2, ..., `num_steps`]
            dec_valid_lens = torch.arange(1, num_steps + 1, 
                                          device=X.device).repeat(batch_size, 1)
        else:
            # During prediction, decoding is token by token, so no mask is needed
            dec_valid_lens = None
        
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
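
To sanity-check the block, we can run a dummy forward pass in training mode (the default after construction, so the causal `dec_valid_lens` mask is applied). The hyperparameter values below are made-up toy numbers, and the sketch assumes the d2l components (`MultiHeadAttention`, `AddNorm`, `PositionWiseFFN`) behave as in the earlier sections:

# Toy hyperparameters, chosen only for the shape check
num_hiddens, num_heads, ffn_num_hiddens = 24, 8, 48
batch_size, num_steps = 2, 10

decoder_blk = DecoderBlock(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
                           [num_hiddens], ffn_num_hiddens, num_heads, 0.5, 0)
X = torch.ones((batch_size, num_steps, num_hiddens))
# Stand-in encoder outputs with the same feature size, plus their valid lengths
enc_outputs = torch.ones((batch_size, num_steps, num_hiddens))
enc_valid_lens = torch.tensor([3, 2])
state = [enc_outputs, enc_valid_lens, [None]]
# Expected output shape: (batch_size, num_steps, num_hiddens)
print(decoder_blk(X, state)[0].shape)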

Decoder#

class TransformerDecoder(d2l.Decoder):
    """Transformer解码器"""
    def __init__(self, vocab_size, num_hiddens, 
                 norm_shape, ffn_num_hiddens, num_heads, 
                 num_layers, dropout):
        super(TransformerDecoder, self).__init__()
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # The stacked DecoderBlocks
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
                             norm_shape, ffn_num_hiddens, num_heads, 
                             dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # Reserve one slot per block in state[2] (filled during prediction)
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        # Scale the embeddings by sqrt(num_hiddens) before adding positional encoding
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for i, blk in enumerate(self.blks):
            # state[0] and state[1] hold the encoder outputs and valid lengths
            # state[2] is used during prediction to cache each block's output
            # sequence up to the current time step
            X, state = blk(X, state)
        return self.dense(X), state
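
As a quick usage sketch, the decoder maps token indices to per-step scores over the vocabulary. The hyperparameters are again made-up toy values, and the encoder outputs are stand-ins for what a TransformerEncoder from the earlier sections would produce:

vocab_size, num_hiddens, ffn_num_hiddens = 200, 24, 48
num_heads, num_layers, dropout = 8, 2, 0.5
batch_size, num_steps = 2, 10

decoder = TransformerDecoder(vocab_size, num_hiddens, [num_hiddens],
                             ffn_num_hiddens, num_heads, num_layers, dropout)
decoder.eval()  # shape check only; no causal mask is applied in eval mode
# Stand-in encoder outputs; in practice these come from the encoder
enc_outputs = torch.ones((batch_size, num_steps, num_hiddens))
enc_valid_lens = torch.tensor([3, 2])
state = decoder.init_state(enc_outputs, enc_valid_lens)
X = torch.ones((batch_size, num_steps), dtype=torch.long)
output, state = decoder(X, state)
# Expected output shape: (batch_size, num_steps, vocab_size)
print(output.shape)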