Self-Attention
预计学习时间:30分钟
**自注意力(Self-Attention)**是Transformer架构的核心机制,它允许模型在处理序列时关注序列中的所有位置,为每个位置计算加权上下文表示。通过学习输入序列中元素之间的关联强度,自注意力机制能有效捕获长距离依赖关系,大幅提升模型对序列数据的处理能力。
自注意力机制的基本原理
自注意力的核心思想是让序列中的每个元素能够"看到"整个序列,并有选择性地关注相关部分:
注意力计算流程
自注意力通过以下步骤计算:
- 线性投影:将输入向量转换为查询(Q)、键(K)、值(V)三种表示
- 注意力分数:计算查询与所有键的相似度作为注意力分数
- 权重归一化:对分数应用softmax函数获得概率分布
- 加权聚合:使用归一化权重对值向量进行加权求和
# 自注意力机制的基本实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, embed_size):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
# 线性变换矩阵
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
# 输出投影
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x, mask=None):
# x: [batch_size, seq_len, embed_size]
batch_size = x.shape[0]
seq_length = x.shape[1]
# 1. 获取查询、键、值向量
queries = self.query(x) # [batch_size, seq_len, embed_size]
keys = self.key(x) # [batch_size, seq_len, embed_size]
values = self.value(x) # [batch_size, seq_len, embed_size]
# 2. 计算注意力分数 (点积注意力)
# 使用爱因斯坦求和约定进行批量矩阵乘法
energy = torch.einsum("nqd,nkd->nqk", [queries, keys])
# energy: [batch_size, seq_len, seq_len]
# 3. 缩放注意力分数
energy = energy / math.sqrt(self.embed_size)
# 4. 掩蔽填充位置(可选)
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 5. Softmax归一化获取注意力权重
attention = F.softmax(energy, dim=2) # [batch_size, seq_len, seq_len]
# 6. 加权聚合值向量
out = torch.einsum("nqk,nkd->nqd", [attention, values])
# out: [batch_size, seq_len, embed_size]
# 7. 输出线性变换
out = self.fc_out(out)
return out, attention
在实践中,自注意力的计算复杂度是O(n²),其中n是序列长度。这意味着对于长序列,计算成本会迅速增加,成为Transformer模型的主要瓶颈之一。
自注意力的数学表达
标准自注意力机制可以用以下数学公式表示:
其中:
分别是查询、键和值矩阵 是键向量的维度 是缩放因子,防止点积过大导致softmax梯度消失
这种形式的注意力被称为标准点积注意力(Scaled Dot-Product Attention)。
"自注意力能够学习到输入序列中所有位置之间的依赖关系,无论它们之间的距离有多远,这使得它在捕获长距离依赖方面比RNN更有优势。" — Vaswani et al.
各种自注意力变体
自注意力机制有多种变体和扩展,各有特点:
1. 点积注意力与加性注意力
注意力类型 | 计算方式 | 特点 |
---|---|---|
点积注意力 | 计算高效,尤其在高维向量上 | |
加性注意力 | 更适合低维向量,数值稳定性更好 |
2. 硬注意力与软注意力
- 软注意力(Soft Attention):分配连续的注意力权重,完全可微
- 硬注意力(Hard Attention):仅关注单个位置,使用采样,需要强化学习
3. 稀疏注意力变体
为了解决自注意力的计算复杂度问题,多种稀疏注意力机制被提出:
# 简单的局部注意力实现(仅关注上下文窗口)
def local_attention(queries, keys, values, window_size=5, mask=None):
batch_size, seq_len, dim = queries.shape
# 初始化输出
outputs = torch.zeros_like(queries)
# 对每个位置计算局部注意力
for i in range(seq_len):
# 确定局部窗口范围
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
# 获取局部键值
local_keys = keys[:, start:end, :]
local_values = values[:, start:end, :]
# 当前查询
q = queries[:, i:i+1, :] # [batch_size, 1, dim]
# 计算注意力分数
scores = torch.bmm(q, local_keys.transpose(1, 2)) / math.sqrt(dim)
# 应用掩码(如果有)
if mask is not None:
local_mask = mask[:, i:i+1, start:end]
scores = scores.masked_fill(local_mask == 0, float("-1e20"))
# Softmax获取权重
attn_weights = F.softmax(scores, dim=2) # [batch_size, 1, window_size]
# 加权求和
context = torch.bmm(attn_weights, local_values) # [batch_size, 1, dim]
outputs[:, i:i+1, :] = context
return outputs
自注意力的优势与局限
自注意力相比传统序列模型带来多方面优势:
主要优势
- 全局依赖捕获:直接建立任意距离的token间联系
- 并行计算:所有位置可以同时计算
- 直观可解释:注意力权重提供了模型决策的可视化解释
- 灵活适应:同样的机制适用于不同长度的序列
关键局限
- 二次计算复杂度:计算需求随序列长度平方增长
- 位置信息缺失:需要额外的位置编码提供顺序信息
- 高内存需求:需要存储完整的注意力矩阵
# 可视化自注意力权重
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def visualize_attention(tokens, attention_weights):
"""可视化自注意力权重"""
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights,
xticklabels=tokens,
yticklabels=tokens,
cmap="YlGnBu",
annot=False
)
plt.title("Self-Attention Weights")
plt.xlabel("Key Tokens")
plt.ylabel("Query Tokens")
plt.tight_layout()
return plt
# 示例使用
tokens = ["The", "cat", "sits", "on", "the", "mat", "."]
# 随机生成的注意力权重矩阵
random_attention = np.random.rand(len(tokens), len(tokens))
random_attention = random_attention / random_attention.sum(axis=1, keepdims=True)
attention_plot = visualize_attention(tokens, random_attention)
自注意力在不同任务中的应用
自注意力已成为多种NLP和计算机视觉任务的关键组件:
语言建模
- 捕获词之间的上下文依赖关系
- 处理长距离语法关系和指代消解
机器翻译
- 在编码器中建立源语言词之间的关系
- 在编码器-解码器注意力中关联目标语言和源语言
文本摘要
- 识别文本中的关键信息
- 学习句子间的关系
计算机视觉
- 处理图像作为"像素序列"
- 捕获图像不同区域之间的关系
实现中的关键技巧
在实际实现自注意力时,一些技巧至关重要:
1. 掩码操作
- 填充掩码(Padding Mask):防止模型关注填充位置
- 未来掩码(Future Mask):在自回归生成中防止信息泄露
2. 注意力丢弃(Attention Dropout)
在softmax后应用Dropout,增加模型鲁棒性:
# 带Dropout的自注意力
def attention_with_dropout(query, key, value, mask=None, dropout=None):
"点积注意力实现,带掩码和dropout"
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim=-1)
if dropout is not None:
attn = dropout(attn)
return torch.matmul(attn, value), attn
自注意力机制作为Transformer和现代大型语言模型的核心部件,彻底改变了序列数据建模的方式。通过理解其工作原理和优化手段,我们能更好地理解和利用这些强大模型的内部机制。