Self Attention
预计学习时间:25分钟
自注意力(Self Attention)是Transformer架构的核心组件,能够捕获序列内部的依赖关系。与传统的RNN不同,自注意力通过直接计算序列中所有位置之间的关联,实现了并行计算和捕获长距离依赖的能力。
自注意力的基本原理
自注意力机制的核心思想是让序列中的每个元素都能"关注"序列中的所有其他元素,并根据关联程度分配不同的注意力权重:
其中:
(查询):当前位置的表示,用于与其他位置进行匹配 (键):所有位置的表示,用于被查询匹配 (值):所有位置的表示,用于信息聚合 :键向量的维度,用于缩放点积以避免梯度消失
在自注意力中,
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
# 线性变换矩阵
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
# x: [batch_size, seq_len, d_model]
batch_size, seq_len, _ = x.size()
# 线性变换
q = self.q_linear(x) # [batch_size, seq_len, d_model]
k = self.k_linear(x) # [batch_size, seq_len, d_model]
v = self.v_linear(x) # [batch_size, seq_len, d_model]
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_model)
# [batch_size, seq_len, seq_len]
# 应用掩码(如果提供)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 应用softmax得到注意力权重
attn_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
# 加权汇总值向量
output = torch.matmul(attn_weights, v) # [batch_size, seq_len, d_model]
return self.out(output)
自注意力的计算过程
1. 输入嵌入和线性投影
将输入序列转换为嵌入向量后,通过线性变换生成查询(Q)、键(K)和值(V):
def generate_qkv(input_embeddings, w_q, w_k, w_v):
"""生成查询、键、值向量"""
# input_embeddings: [batch_size, seq_len, d_model]
# w_q, w_k, w_v: [d_model, d_model]
q = torch.matmul(input_embeddings, w_q) # [batch_size, seq_len, d_model]
k = torch.matmul(input_embeddings, w_k) # [batch_size, seq_len, d_model]
v = torch.matmul(input_embeddings, w_v) # [batch_size, seq_len, d_model]
return q, k, v
2. 计算注意力分数
计算查询与所有键的点积,并进行缩放:
def compute_attention_scores(q, k, scale=True):
"""计算注意力分数矩阵"""
# q, k: [batch_size, seq_len, d_model]
# 计算点积
scores = torch.matmul(q, k.transpose(-2, -1)) # [batch_size, seq_len, seq_len]
# 缩放点积
if scale:
d_k = q.size(-1)
scores = scores / math.sqrt(d_k)
return scores
3. 应用掩码(可选)
对于需要掩盖的位置(如未来位置或填充位置),将分数设为负无穷:
def apply_mask(scores, mask):
"""应用注意力掩码"""
# scores: [batch_size, seq_len, seq_len]
# mask: [batch_size, seq_len, seq_len] 二进制掩码
masked_scores = scores.masked_fill(mask == 0, -1e9)
return masked_scores
掩码在自注意力中至关重要,特别是在解码器中,我们需要防止模型看到未来的信息。
4. Softmax归一化
将注意力分数转换为概率分布:
def normalize_scores(scores):
"""将分数归一化为概率分布"""
# scores: [batch_size, seq_len, seq_len]
attention_weights = F.softmax(scores, dim=-1)
return attention_weights
5. 加权汇总
使用注意力权重对值向量进行加权求和:
def weighted_aggregation(attention_weights, v):
"""基于注意力权重聚合值向量"""
# attention_weights: [batch_size, seq_len, seq_len]
# v: [batch_size, seq_len, d_model]
output = torch.matmul(attention_weights, v) # [batch_size, seq_len, d_model]
return output
自注意力的特点
优势
-
并行计算:与RNN不同,自注意力可以并行计算所有位置,提高训练效率
-
捕获长距离依赖:每个位置都可以直接关注任何其他位置,无论距离多远
def analyze_attention_span(attention_weights):
"""分析自注意力的关注范围"""
# attention_weights: [batch_size, seq_len, seq_len]
# 计算平均注意力距离
batch_size, seq_len, _ = attention_weights.shape
# 创建距离矩阵
positions = torch.arange(seq_len).unsqueeze(0).unsqueeze(2) # [1, seq_len, 1]
positions_t = positions.transpose(1, 2) # [1, 1, seq_len]
distance = torch.abs(positions - positions_t) # [1, seq_len, seq_len]
# 计算加权平均距离
avg_distance = torch.sum(attention_weights * distance, dim=2) # [batch_size, seq_len]
return avg_distance.mean()
- 表示能力强:可以捕获复杂的序列内部关系,包括语法和语义依赖
局限
-
二次计算复杂度:注意力矩阵大小为
,随序列长度二次增长 -
没有位置信息:需要额外的位置编码提供顺序信息
自注意力本身是位置无关的,即打乱序列的顺序不会改变结果。因此,必须引入位置编码以提供顺序信息。
变体和扩展
1. 带掩码的自注意力
在解码器中使用,确保预测时只使用当前及之前的信息:
def causal_self_attention(x):
"""带因果掩码的自注意力,用于自回归生成"""
batch_size, seq_len, d_model = x.shape
# 创建下三角掩码
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
# 计算自注意力,应用掩码
q = self.q_linear(x)
k = self.k_linear(x)
v = self.v_linear(x)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_model)
scores = scores.masked_fill(mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v)
return output
2. 局部敏感哈希注意力
减少计算复杂度的变体,适用于长序列:
def lsh_attention(q, k, v, n_buckets=256, n_hashes=8):
"""局部敏感哈希注意力,降低长序列的计算复杂度"""
# 简化示例
batch_size, seq_len, d_model = q.shape
# 哈希函数(简化表示)
def hash_vectors(vectors, n_buckets, n_hashes):
# 在实际实现中,这应该使用随机投影等方法
projections = torch.randn(n_hashes, d_model, n_buckets // 2)
projected = torch.matmul(vectors.unsqueeze(1), projections)
hashed = torch.argmax(torch.cat([projected, -projected], dim=-1), dim=-1)
return hashed
# 对查询和键进行哈希
q_buckets = hash_vectors(q, n_buckets, n_hashes) # [batch_size, n_hashes, seq_len]
k_buckets = hash_vectors(k, n_buckets, n_hashes)
# 后续计算:只计算同一桶内的注意力,此处省略复杂实现...
return output
自注意力的应用案例
1. 自然语言处理
自注意力能够捕获语法和语义依赖关系:
def analyze_syntactic_attention(model, sentence):
"""分析自注意力如何捕获句法结构"""
# 对句子进行标记化
tokens = tokenizer(sentence, return_tensors="pt")
# 前向传播,获取注意力权重
with torch.no_grad():
outputs = model(**tokens, output_attentions=True)
# 取出特定层的注意力权重
layer_attentions = outputs.attentions[6] # 第7层注意力
# 可视化句法关系
for head_idx in [0, 3, 5]: # 选择几个特定的注意力头
attn = layer_attentions[0, head_idx].numpy()
# 绘制热力图
plt.figure(figsize=(10, 8))
plt.imshow(attn, cmap="viridis")
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.title(f"Head {head_idx} Attention")
plt.colorbar()
plt.tight_layout()
plt.show()
2. 计算机视觉
自注意力用于捕获图像中的长距离依赖:
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3,
embed_dim=768, depth=12, num_heads=12):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
num_patches = (img_size // patch_size) ** 2
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads) for _ in range(depth)
])
def forward(self, x):
# x: [batch_size, channels, height, width]
x = self.patch_embed(x) # [batch_size, num_patches, embed_dim]
# 添加分类令牌
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls_token, x], dim=1)
# 添加位置编码
x = x + self.pos_embed
# 应用Transformer块
for block in self.blocks:
x = block(x)
# 使用[CLS]令牌的表示进行分类
return x[:, 0]
自注意力的效率和优化
计算复杂度分析
标准自注意力的计算复杂度:
操作 | 时间复杂度 | 空间复杂度 |
---|---|---|
QK^T计算 | O(n²d) | O(n²) |
Softmax | O(n²) | O(n²) |
乘以V | O(n²d) | O(nd) |
总计 | O(n²d) | O(n²) |
其中n是序列长度,d是隐藏维度。
优化方法
对于长序列处理,可以采用以下优化方法:
def efficient_attention(q, k, v, chunk_size=128):
"""分块计算自注意力以优化内存使用"""
batch_size, seq_len, d_model = q.shape
# 初始化输出
output = torch.zeros_like(q)
# 分块计算
for i in range(0, seq_len, chunk_size):
end_idx = min(i + chunk_size, seq_len)
# 当前块的查询
q_chunk = q[:, i:end_idx]
# 计算当前块与所有键的注意力
chunk_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(d_model)
chunk_weights = F.softmax(chunk_scores, dim=-1)
# 聚合值向量
chunk_output = torch.matmul(chunk_weights, v)
# 更新输出
output[:, i:end_idx] = chunk_output
return output
小结
自注意力机制是Transformer架构的基础,通过直接计算序列内部的依赖关系,实现了捕获长距离依赖和并行计算的能力:
- 核心优势:全局上下文建模、捕获长距离依赖、并行计算
- 关键计算步骤:生成查询/键/值、计算注意力分数、归一化、加权聚合
- 主要挑战:二次计算复杂度、缺乏内在的位置信息
- 优化方向:降低计算复杂度、提高长序列处理能力
随着大语言模型的发展,自注意力的效率优化变得越来越重要,各种稀疏注意力和线性复杂度注意力变体也应运而生,成为推动大规模模型发展的关键技术。