多头注意力¶
Note
在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来。
模型¶
多头注意力首先用独立学习得到的 \(h\) 组不同的线性投影(linear projections)来变换查询、键和值。
然后这 \(h\) 组变换后的查询、键和值将并行地送到注意力汇聚中。
最后将这 \(h\) 个注意力汇聚拼接在一起,经过另一个线性投影产生最终的输出。
让我们用数学的语言将这个模型描述出来。给定查询 \(\mathbf{q}\in\mathbb{R}^{d_q}\)、键 \(\mathbf{k}\in\mathbb{R}^{d_k}\) 和值 \(\mathbf{v}\in\mathbb{R}^{d_v}\),每个注意力头 \(\mathbf{h}_{i}(i=1,...,h)\) 的计算方法为:
其中可学习的参数包括 \(\mathbf{W}_{i}^{(q)}\in\mathbb{R}^{p_{q}\times{d_{q}}}\) , \(\mathbf{W}_{i}^{(k)}\in\mathbb{R}^{p_{k}\times{d_{k}}}\) , \(\mathbf{W}_{i}^{(v)}\in\mathbb{R}^{p_{v}\times{d_{v}}}\) 和注意汇聚函数 \(f\),\(f\) 可以是加性注意力或是缩放点积注意力。
最后把 \(h\) 个头连接后的进行线性变换:
其中可学习的参数是 \(\mathbf{W}_{o}\in\mathbb{R}^{p_{o}\times{h{p_v}}}\)。
实现¶
在实现过程中,我们使用缩放点积注意力作为每一个注意力头。为简单起见,我们设定 \(p_{q}=p_{k}=p_{v}=\frac{p_{o}}{h}\)。
import torch
from torch import nn
import d2l
#@save
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
# `num_heads`个线性变换拼接起来,所以`num_hiddens`应可以整除`num_heads`
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs, `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens,
repeats=self.num_heads, dim=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries, `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`: (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
#@save
def transpose_qkv(X, num_heads):
"""改变X的shape"""
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""逆转`transpose_qkv`的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
自注意力¶
给定一个词元组成的输入序列 \(\mathbf{x}_{1},...,\mathbf{x}_{n}\),其中 \(\mathbf{x}_{i}\in\mathbb{R}^{d}\)。
改序列的自注意力输出一个长度相同的序列 \(\mathbf{y}_{1},...,\mathbf{y}_{n}\),其中:
num_hiddens, num_heads = 100, 5
# 这里 `d` = `num_hiddens`
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
# 形状相同
attention(X, X, X, valid_lens).shape
torch.Size([2, 4, 100])