Attention Scoring Functions¶
Suppose we have a query \(\mathbf{q}\in\mathbb{R}^{q}\) and \(m\) key-value pairs \((\mathbf{k}_{1}, \mathbf{v}_{1}),\ldots,(\mathbf{k}_{m}, \mathbf{v}_{m})\), where \(\mathbf{k}_{i}\in\mathbb{R}^{k}\) and \(\mathbf{v}_{i}\in\mathbb{R}^{v}\). The attention pooling function is expressed as a weighted sum of the values:
\[f(\mathbf{q}, (\mathbf{k}_{1}, \mathbf{v}_{1}),\ldots,(\mathbf{k}_{m}, \mathbf{v}_{m})) = \sum_{i=1}^{m} \alpha(\mathbf{q}, \mathbf{k}_{i})\, \mathbf{v}_{i} \in \mathbb{R}^{v},\]
where an attention scoring function \(a\) maps the two vectors, the query \(\mathbf{q}\) and the key \(\mathbf{k}_{i}\), to a scalar \(a(\mathbf{q}, \mathbf{k}_{i})\), and the attention weight of key \(\mathbf{k}_{i}\) is obtained by passing these scores through a softmax operation:
\[\alpha(\mathbf{q}, \mathbf{k}_{i}) = \mathrm{softmax}\bigl(a(\mathbf{q}, \mathbf{k}_{i})\bigr) = \frac{\exp\bigl(a(\mathbf{q}, \mathbf{k}_{i})\bigr)}{\sum_{j=1}^{m} \exp\bigl(a(\mathbf{q}, \mathbf{k}_{j})\bigr)} \in \mathbb{R}.\]
There are many different attention scoring functions. This section introduces two popular ones: additive attention and scaled dot-product attention.
Masked Softmax¶
As shown in the formula above, the softmax operation is used to output a probability distribution as the attention weights.
However, in many cases not all values should be included in attention pooling, for example the padding tokens in a text sequence.
The following function implements such a masked softmax, where any element beyond the valid length is masked out before the softmax is computed.
import math
import torch
from torch import nn
from d2l import torch as d2l

#@save
def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements on the last axis."""
    # Shape of X: (`batch_size`, no. of queries, no. of key-value pairs)
    # Shape of valid_lens: None, (`batch_size`,), or (`batch_size`, no. of queries)
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        # Flatten valid_lens to shape (`batch_size` * no. of queries,)
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponential is approximately 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
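A quick check makes the behavior concrete (the random inputs and valid lengths below are arbitrary illustrative values): in each row, weights beyond the valid length become essentially zero, while the remaining weights still sum to one.

# 1D valid_lens: one valid length per example in the batch
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
# 2D valid_lens: one valid length per query
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))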
Additive Attention¶
In general, when queries and keys are vectors of different lengths, additive attention can be used as the scoring function.
Given a query \(\mathbf{q} \in \mathbb{R}^{q}\) and a key \(\mathbf{k} \in \mathbb{R}^{k}\), the additive attention scoring function is
\[a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_{h}^\top \tanh\bigl(\mathbf{W}_{q}\mathbf{q} + \mathbf{W}_{k}\mathbf{k}\bigr) \in \mathbb{R},\]
where the learnable parameters are \(\mathbf{W}_{q} \in \mathbb{R}^{h\times{q}}\), \(\mathbf{W}_{k} \in \mathbb{R}^{h\times{k}}\), and \(\mathbf{w}_{h} \in \mathbb{R}^{h}\). In effect, the query and key are projected into a shared \(h\)-dimensional space, combined, passed through a tanh activation, and reduced to a scalar, with all bias terms disabled.
#@save
class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout):
        super(AdditiveAttention, self).__init__()
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries: (`batch_size`, no. of queries, `query_size`)
        # Shape of keys: (`batch_size`, no. of key-value pairs, `key_size`)
        # Shape of values: (`batch_size`, no. of key-value pairs, `value_size`)
        # Shape of valid_lens: (`batch_size`,) or (`batch_size`, no. of queries)
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion,
        # shape of queries: (`batch_size`, no. of queries, 1, `num_hiddens`)
        # shape of keys: (`batch_size`, 1, no. of key-value pairs, `num_hiddens`)
        # Sum them with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # `self.w_v` has a single output, so the trailing one-dimensional
        # entry is removed from the shape.
        # Shape of scores: (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output shape: (`batch_size`, no. of queries, `value_size`)
        return torch.bmm(self.dropout(self.attention_weights), values)
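The following is a minimal usage sketch; the toy tensor shapes, the hyperparameter values, and the `eval()` call to disable dropout are illustrative choices, not part of the text above.

# Toy inputs: batch of 2 examples, 1 query of size 20, 10 key-value pairs
# with 2-dimensional keys and 4-dimensional values (arbitrary shapes)
queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
# Both examples share the same 10 x 4 value matrix
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()  # disable dropout for a deterministic check
# Output shape: (2, 1, 4), i.e. (`batch_size`, no. of queries, `value_size`)
attention(queries, keys, values, valid_lens)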
Scaled Dot-Product Attention¶
Using the dot product yields a computationally more efficient scoring function, but it requires the query and the key to have the same length \(d\).
Assume that all elements of the query and the key are independent random variables with zero mean and unit variance; then the dot product of the two vectors has mean 0 and variance \(d\). To ensure that the variance of the attention score is 1 regardless of the vector length, the dot product is divided by \(\sqrt{d}\):
\[a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.\]
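The scaling argument can be checked empirically. The sketch below (the dimension and sample count are arbitrary choices) estimates the variance of raw and scaled dot products between random Gaussian vectors.

# For q, k with i.i.d. standard normal entries, q.k has variance close to d,
# while q.k / sqrt(d) has variance close to 1 (d and sample count are arbitrary)
d = 64
q = torch.randn(10000, d)
k = torch.randn(10000, d)
dots = (q * k).sum(dim=1)
print(dots.var())                   # roughly d
print((dots / math.sqrt(d)).var())  # roughly 1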
From a minibatch perspective, suppose there are \(n\) queries \(\mathbf{Q}\in\mathbb{R}^{n\times{d}}\) and \(m\) key-value pairs \(\mathbf{K}\in\mathbb{R}^{m\times{d}}\), \(\mathbf{V}\in\mathbb{R}^{m\times{v}}\). The scaled dot-product attention is then
\[\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n\times{v}}.\]
#@save
class DotProductAttention(nn.Module):
    """Scaled dot-product attention."""
    def __init__(self, dropout):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # Shape of queries: (`batch_size`, no. of queries, `d`)
        # Shape of keys: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of values: (`batch_size`, no. of key-value pairs, `value_size`)
        # Shape of valid_lens: (`batch_size`,) or (`batch_size`, no. of queries)
        d = queries.shape[-1]
        # Shape of scores: (`batch_size`, no. of queries, no. of key-value pairs)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output shape: (`batch_size`, no. of queries, `value_size`)
        return torch.bmm(self.dropout(self.attention_weights), values)
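A matching usage sketch for the scaled dot-product version, reusing the toy `keys`, `values`, and `valid_lens` defined in the additive-attention example above; these inputs are illustrative, and the queries now have the same dimension (2) as the keys, as the dot product requires.

# Queries share the keys' dimension of 2 (arbitrary toy shapes)
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()  # disable dropout for a deterministic check
# Output shape: (2, 1, 4), i.e. (`batch_size`, no. of queries, `value_size`)
attention(queries, keys, values, valid_lens)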