位置编码¶
为了使用序列的顺序信息,我们通过在输入表示中添加位置编码(positional encoding)来注入绝对的或相对的位置信息。
假设输入为 X∈Rn×d,位置编码使用相同形状的位置嵌入矩阵 P∈Rn×d 输出 X+P,其中:
pi,2j=sin(i100002j/d)
pi,2j+1=cos(i100002j/d)
行用 sin,cos 的位置来表示。
列用 sin,cos 的频率来表示。
import torch
from torch import nn
#@save
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的 `P`
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
# `num_hideens`必须为偶数,不然shape对不上
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
位置编码就像是二进制表示: