Transformer Decoder#
Note
本节我们使用之前介绍过的组件来搭建 Transformer Decoder。
解码器 Block#
如结构图所示,Transformer的解码器也是由多个结构相同的层组成。
DecoderBlock包含三个子层:解码器自注意力、“编码器-解码器”注意力和基于位置的前馈网络。
在遮蔽多头自注意力层(Masked multi-head attention,第一层)中,查询、键和值都来自于上一个解码器层的输出。在训练阶段,输出序列所有时间步的词元都是已知的;然而,在预测阶段,其输出序列的词元是逐个生成的。
import torch
from torch import nn
import d2l
import math
class DecoderBlock(nn.Module):
    """The i-th block of the Transformer decoder.

    Three sub-layers, each wrapped in an add & layer-norm step:
      1. masked multi-head self-attention,
      2. encoder-decoder attention,
      3. position-wise feed-forward network.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 dropout, i):
        super(DecoderBlock, self).__init__()
        # Index of this block; selects its cache slot in state[2].
        self.i = i
        # Sub-layer 1: masked multi-head self-attention.
        self.attention1 = d2l.MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        # Sub-layer 2: encoder-decoder ("cross") attention.
        self.attention2 = d2l.MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)
        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = d2l.PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm3 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run one decoder block.

        Args:
            X: training: (batch_size, num_steps, num_hiddens);
               prediction: (batch_size, 1, num_hiddens).
            state: list [enc_outputs, enc_valid_lens, caches] where
                enc_outputs is the output of the last encoder block,
                shape (batch_size, num_steps, num_hiddens),
                enc_valid_lens also comes from the encoder, and
                state[2][self.i] caches this block's outputs over the
                time steps decoded so far (None before the first step).

        Returns:
            Tuple (block output, updated state).
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # state[2][self.i] is used at prediction time; it is initialized
        # to None and accumulates this block's output sequence so far.
        if state[2][self.i] is None:
            # Training, or predicting the very first token.
            key_values = X
        else:
            # Unlike an RNN seq2seq model, the Transformer attends over the
            # whole output sequence generated so far, not just the previous
            # time step.
            # key_values shape: (batch_size, cur_steps, num_hiddens)
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask ("no cheating"): position t may only attend to
            # positions 1..t, so the model cannot peek at future tokens.
            # dec_valid_lens shape: (batch_size, num_steps),
            # each row is [1, 2, ..., num_steps].
            dec_valid_lens = torch.arange(1, num_steps + 1,
                                          device=X.device).repeat(batch_size, 1)
        else:
            # Prediction decodes token by token, so no mask is needed.
            dec_valid_lens = None
        # Masked self-attention.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries come from the decoder,
        # keys/values from the encoder outputs.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
解码器#
class TransformerDecoder(d2l.Decoder):
    """Transformer decoder: token embedding plus positional encoding,
    a stack of DecoderBlocks, and a final linear projection onto the
    vocabulary."""

    def __init__(self, vocab_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 num_layers, dropout):
        super(TransformerDecoder, self).__init__()
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # Stack of structurally identical decoder blocks.
        self.blks = nn.Sequential()
        for blk_idx in range(num_layers):
            blk = DecoderBlock(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, norm_shape, ffn_num_hiddens,
                               num_heads, dropout, blk_idx)
            self.blks.add_module("block" + str(blk_idx), blk)
        # Maps hidden states to vocabulary logits.
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state; the third slot reserves one
        (initially empty) per-block cache for prediction."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        # Embed, rescale by sqrt(num_hiddens), then add positional encoding.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # state[0] and state[1] carry the encoder outputs and valid lengths;
        # state[2] stores, per block, the output sequence decoded so far
        # (used during prediction).
        for blk in self.blks:
            X, state = blk(X, state)
        return self.dense(X), state