Multi-Head Attention
Estimated study time: 20 minutes
Multi-head attention is one of the core innovations of the Transformer architecture. By splitting the attention mechanism into several "heads", it lets the model attend to different representation subspaces at the same time, improving its expressive power and performance.
The basic idea of multi-head attention
Multi-head attention projects the input into several subspaces, computes attention independently in each subspace, and then concatenates the results:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O$$

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

where:

- $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are the projection matrices of head $i$
- $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ is the output projection
- $h$ is the number of heads

The dimension of each head is usually a fraction of the model dimension: with 8 heads, each head has dimension $d_{\text{model}}/8$ (and typically $d_k = d_v$). This keeps the total amount of computation roughly the same as single-head attention while gaining the benefits of multiple heads.
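As a quick sanity check of these sizes, using illustrative values d_model = 512 and num_heads = 8 (not tied to any particular model): the fused projection matrix has exactly as many parameters as the per-head projections stacked side by side.

```python
d_model, num_heads = 512, 8          # assumed example values
d_k = d_model // num_heads           # 64: dimension of each head

# One fused W^Q of shape (d_model, d_model) has the same parameter count as
# 8 independent per-head projections of shape (d_model, d_k).
fused_params = d_model * d_model
per_head_params = num_heads * (d_model * d_k)
assert fused_params == per_head_params == 262144
```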
How multi-head attention is computed
1. Linear projections
Project the query, key, and value vectors into multiple subspaces:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head

        # Linear projection layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
2. Splitting into heads
Split the projected vectors into multiple heads:
    def split_heads(self, x, batch_size):
        """Split a tensor into multi-head form.

        x: [batch_size, seq_len, d_model]
        returns: [batch_size, num_heads, seq_len, d_k]
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        # [batch_size, seq_len, num_heads, d_k] -> [batch_size, num_heads, seq_len, d_k]
        return x.transpose(1, 2)
3. Computing attention
Compute scaled dot-product attention independently in each head:
    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention, computed independently in every head.

        q, k, v: [batch_size, num_heads, seq_len, d_k]
        """
        # Attention scores: [batch_size, num_heads, seq_len_q, seq_len_k]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Apply the mask, if one is provided (0 marks masked positions)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Attention weights: [batch_size, num_heads, seq_len_q, seq_len_k]
        attention_weights = torch.softmax(scores, dim=-1)
        # Weighted sum of the values: [batch_size, num_heads, seq_len_q, d_k]
        output = torch.matmul(attention_weights, v)
        return output, attention_weights
4. Combining the heads
Concatenate the outputs of all heads:
    def combine_heads(self, x, batch_size):
        """Merge the multi-head tensor back into its original shape.

        x: [batch_size, num_heads, seq_len, d_k]
        returns: [batch_size, seq_len, d_model]
        """
        # [batch_size, seq_len, num_heads, d_k]
        x = x.transpose(1, 2)
        return x.contiguous().view(batch_size, -1, self.d_model)
5. Final linear projection
Apply the final output projection:
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Linear projections: [batch_size, seq_len_*, d_model]
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        # Split into heads: [batch_size, num_heads, seq_len_*, d_k]
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        # Scaled dot-product attention
        # attn_output: [batch_size, num_heads, seq_len_q, d_k]
        attn_output, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)
        # Combine the heads: [batch_size, seq_len_q, d_model]
        output = self.combine_heads(attn_output, batch_size)
        # Final linear layer: [batch_size, seq_len_q, d_model]
        output = self.output_linear(output)
        return output, attention_weights
A complete multi-head attention implementation
Putting the steps above together into a single PyTorch module:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projection layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x, batch_size):
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        # Linear projections
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        # Split into heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        # Attention scores and weights
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask: [batch_size, seq_len_q, seq_len_k] (or broadcastable), 0 = masked
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attn_output = torch.matmul(attention_weights, v)
        # Combine the heads
        output = self.combine_heads(attn_output, batch_size)
        # Final linear layer
        output = self.output_linear(output)
        return output, attention_weights
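A minimal usage sketch of the module above, with assumed toy sizes (batch 2, sequence length 10, d_model 512, 8 heads):

```python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)        # [batch_size, seq_len, d_model]
output, weights = mha(x, x, x)     # self-attention: q = k = v = x
print(output.shape)                # torch.Size([2, 10, 512])
print(weights.shape)               # torch.Size([2, 8, 10, 10])
```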
Advantages of multi-head attention
1. Richer representations
Multiple heads let the model attend to different representation subspaces at the same time:
import numpy as np

def analyze_head_specialization(model, sentence):
    """Measure how specialized each attention head is.

    Assumes `model` and `tokenizer` are a Hugging Face model/tokenizer pair
    and the model returns attention weights (output_attentions=True).
    """
    tokens = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)
    # Attention weights of one layer (here the 7th layer, index 6)
    layer_attention = outputs.attentions[6]
    # Inspect which positions each head attends to
    specializations = []
    for head_idx in range(layer_attention.shape[1]):
        head_attention = layer_attention[0, head_idx].numpy()
        # Attention entropy (how spread out the attention is)
        entropy = -np.sum(
            head_attention * np.log(head_attention + 1e-9),
            axis=1
        ).mean()
        # Strongest attention link for every query position
        strongest_connections = []
        for i in range(head_attention.shape[0]):
            max_idx = np.argmax(head_attention[i])
            strongest_connections.append((i, max_idx))
        specializations.append({
            "head_idx": head_idx,
            "entropy": entropy,
            "strongest_connections": strongest_connections
        })
    return specializations
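The helper above assumes a `model` and `tokenizer` are available. One way to obtain them is a Hugging Face encoder; `bert-base-uncased` below is an illustrative choice, and any encoder that returns attentions would do:

```python
from transformers import AutoModel, AutoTokenizer

# Illustrative choice of encoder; any model that exposes attentions works.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

specs = analyze_head_specialization(model, "The cat sat on the mat.")
for spec in specs:
    print(spec["head_idx"], round(float(spec["entropy"]), 3))
```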
2. Improved stability
Averaging over several heads reduces the influence of noise on any single attention pattern. One way to probe this is to measure how much the output changes under small input perturbations:
def test_attention_stability(model, sentence, num_runs=10, noise_level=0.01):
    """Probe how stable the model's output is under small input noise.

    Assumes `model` and `tokenizer` are a Hugging Face model/tokenizer pair.
    """
    tokens = tokenizer(sentence, return_tensors="pt")
    input_embeds = model.get_input_embeddings()(tokens.input_ids)
    # Output for the clean input
    with torch.no_grad():
        original_output = model(inputs_embeds=input_embeds).last_hidden_state
    # Outputs for noisy inputs
    noisy_outputs = []
    for _ in range(num_runs):
        noise = torch.randn_like(input_embeds) * noise_level
        noisy_embeds = input_embeds + noise
        with torch.no_grad():
            noisy_output = model(inputs_embeds=noisy_embeds).last_hidden_state
        noisy_outputs.append(noisy_output)
    # Stability metric: mean squared deviation from the clean output
    variances = torch.stack([
        torch.mean((output - original_output) ** 2)
        for output in noisy_outputs
    ])
    return {
        "mean_variance": variances.mean().item(),
        "max_variance": variances.max().item()
    }
3. Efficient parallel computation
All heads can be computed in parallel as one batched operation, which improves hardware utilization:
# An efficient multi-head attention core that batches all heads together
def efficient_multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    # q, k, v: [batch_size, seq_len, d_model], already linearly projected
    batch_size, seq_len, _ = q.shape
    d_k = d_model // num_heads
    # Reshape into heads in one step: [batch_size, num_heads, seq_len, d_k]
    q = q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    v = v.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    # Attention for all heads in a single batched matmul
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    attn_output = torch.matmul(weights, v)
    # Combine the heads
    attn_output = attn_output.transpose(1, 2).contiguous().view(
        batch_size, seq_len, d_model
    )
    return attn_output
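For completeness: PyTorch 2.0+ ships a fused kernel, `torch.nn.functional.scaled_dot_product_attention`, that performs the scale/softmax/matmul core of this batched computation. A minimal sketch of delegating to it, assuming the same mask convention as above (0 = masked) and leaving dropout aside:

```python
import torch
import torch.nn.functional as F

def fused_multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    # q, k, v: [batch_size, seq_len, d_model], already linearly projected
    batch_size, seq_len, _ = q.shape
    d_k = d_model // num_heads
    q = q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    v = v.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    # Boolean attn_mask: True = attend, False = masked (converted from the 0/1 mask)
    attn_mask = (mask != 0).unsqueeze(1) if mask is not None else None
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
```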
Applications of multi-head attention
1. In the Transformer encoder
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Multi-head self-attention + residual connection
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward network + residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
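A minimal usage sketch, stacking a few encoder layers with assumed hyperparameters (d_model = 512, 8 heads, d_ff = 2048, 6 layers):

```python
import torch
import torch.nn as nn

layers = nn.ModuleList([
    TransformerEncoderLayer(d_model=512, num_heads=8, d_ff=2048)
    for _ in range(6)
])

x = torch.randn(2, 10, 512)   # [batch_size, seq_len, d_model]
for layer in layers:
    x = layer(x)              # shape is preserved: [2, 10, 512]
print(x.shape)
```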
2. In the Transformer decoder
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Masked multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # Encoder-decoder (cross) multi-head attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        # Masked multi-head self-attention + residual connection
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Encoder-decoder multi-head attention + residual connection
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        # Feed-forward network + residual connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
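The decoder's masked self-attention needs a causal mask. With the convention used by the `MultiHeadAttention` module above (0 = masked, broadcast over batch and heads), a lower-triangular matrix works; a small sketch:

```python
import torch

def make_causal_mask(seq_len):
    # 1 = visible, 0 = masked (future positions); shape [1, seq_len, seq_len]
    # so that mask.unsqueeze(1) inside the attention broadcasts over batch and heads
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long)).unsqueeze(0)

tgt_mask = make_causal_mask(5)
print(tgt_mask[0])
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```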
Variants of multi-head attention
1. Separable multi-head attention
Here every head has its own explicitly separate projection matrices. With a per-head dimension of d_model / num_heads this is mathematically equivalent to the fused implementation above (the fused matrices are simply the per-head matrices stacked side by side, see the check after the code), but the explicit form makes it easy to give individual heads different or larger dimensions at the cost of extra parameters, and the per-head loop is slower than the batched version:
class SeparableMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        # Each head gets its own projection matrices
        self.q_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.k_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.v_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.output_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # Collect the output of every head
        head_outputs = []
        for i in range(self.num_heads):
            # Per-head projections: [batch_size, seq_len, d_k]
            q_i = self.q_linears[i](q)
            k_i = self.k_linears[i](k)
            v_i = self.v_linears[i](v)
            # Attention within this head
            scores = torch.matmul(q_i, k_i.transpose(-2, -1)) / math.sqrt(self.d_k)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)
            weights = torch.softmax(scores, dim=-1)
            weights = self.dropout(weights)
            head_outputs.append(torch.matmul(weights, v_i))
        # Concatenate the heads: [batch_size, seq_len, d_model]
        output = torch.cat(head_outputs, dim=-1)
        # Final linear projection
        output = self.output_linear(output)
        return output
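A quick check that, with d_k = d_model / num_heads, this variant has exactly the same number of parameters as the fused module (512 and 8 are assumed example values):

```python
def count_params(module):
    return sum(p.numel() for p in module.parameters())

fused = MultiHeadAttention(d_model=512, num_heads=8)
separable = SeparableMultiHeadAttention(d_model=512, num_heads=8)
print(count_params(fused), count_params(separable))  # identical totals
```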
2. Multi-head attention with relative position encoding
Adding relative position information strengthens sequence modelling:
class RelativeMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, max_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.max_len = max_len
        # Linear projections
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        # Learned embeddings for relative offsets -(max_len-1) .. (max_len-1)
        self.rel_pos_encoding = nn.Parameter(
            torch.randn(2 * max_len - 1, self.d_k)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        # Linear projections
        q = self.q_linear(q).view(batch_size, seq_len, self.num_heads, self.d_k)
        k = self.k_linear(k).view(batch_size, seq_len, self.num_heads, self.d_k)
        v = self.v_linear(v).view(batch_size, seq_len, self.num_heads, self.d_k)
        # [batch_size, num_heads, seq_len, d_k]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Content-content attention scores
        content_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Content-position attention scores (simplified implementation):
        # look up an embedding for the relative offset (i - j) of every query/key pair,
        # centering offset 0 at row max_len - 1 of the embedding table
        positions = torch.arange(seq_len, device=q.device)
        rel_pos_indices = positions.unsqueeze(1) - positions.unsqueeze(0) + self.max_len - 1
        rel_pos = self.rel_pos_encoding[rel_pos_indices]  # [seq_len, seq_len, d_k]
        position_scores = torch.einsum("bhid,ijd->bhij", q, rel_pos) / math.sqrt(self.d_k)
        # Combine the two score terms
        scores = content_scores + position_scores
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        output = torch.matmul(weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.output_linear(output)
        return output
Analyzing multi-head attention
Visualizing attention heads
Plotting the attention pattern of each head shows what it focuses on:
import matplotlib.pyplot as plt

def visualize_attention_heads(model, text, layer_idx=6):
    """Visualize the attention pattern of every head in one layer.

    Assumes `model` and `tokenizer` are a Hugging Face model/tokenizer pair
    and the model returns attention weights (output_attentions=True).
    """
    # Tokenize the input
    tokens = tokenizer(text, return_tensors="pt")
    token_strings = tokenizer.convert_ids_to_tokens(tokens.input_ids[0])
    # Get the attention weights
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)
    # Attention weights of the requested layer: [num_heads, seq_len, seq_len]
    attention = outputs.attentions[layer_idx][0]
    # One heatmap per head, four per row
    num_heads = attention.shape[0]
    num_rows = math.ceil(num_heads / 4)
    fig, axes = plt.subplots(num_rows, 4, figsize=(20, 5 * num_rows))
    axes = axes.flatten()
    for head_idx in range(num_heads):
        ax = axes[head_idx]
        im = ax.imshow(attention[head_idx].numpy(), cmap="viridis")
        ax.set_title(f"Head {head_idx}")
        ax.set_xticks(range(len(token_strings)))
        ax.set_yticks(range(len(token_strings)))
        ax.set_xticklabels(token_strings, rotation=90)
        ax.set_yticklabels(token_strings)
        plt.colorbar(im, ax=ax)
    # Hide any unused subplots
    for ax in axes[num_heads:]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()
What different heads attend to
Different heads in multi-head attention tend to capture different kinds of information (a small sketch for spotting position-focused heads follows the table):

| Head type | Tends to capture | Typical observations |
|---|---|---|
| Syntactic heads | Syntactic dependencies | Attend to subject-verb, preposition-noun pairs, etc. |
| Semantic heads | Semantic relatedness | Attend to related concepts, synonyms, and context |
| Positional heads | Positional relations | Attend to adjacent tokens or tokens at a fixed offset |
| Global heads | Global information | Attend to all tokens almost uniformly |
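As a rough way to spot positional heads in the sense of this table, one can measure how much attention each head places on the immediately preceding token; a minimal sketch over the attention tensors returned by a Hugging Face model (function and variable names are illustrative):

```python
import torch

def previous_token_score(attentions, layer_idx):
    """Average attention each head pays to the previous token in one layer.

    attentions: tuple of [batch, num_heads, seq_len, seq_len] tensors,
                e.g. outputs.attentions from a Hugging Face model.
    Returns a [num_heads] tensor; values close to 1 suggest a positional head.
    """
    attn = attentions[layer_idx][0]               # [num_heads, seq_len, seq_len]
    seq_len = attn.shape[-1]
    # attention from token i to token i-1, for every i >= 1
    prev = attn[:, torch.arange(1, seq_len), torch.arange(0, seq_len - 1)]
    return prev.mean(dim=-1)
```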
Summary
Multi-head attention runs several independent attention computations in parallel, improving both the expressive power and the stability of the model:
- Core idea: split one high-dimensional attention computation into several parallel low-dimensional ones
- Main advantages:
  - Richer representations that capture different kinds of dependencies
  - Better stability and robustness
  - Efficient parallel computation
- Implementation essentials: projection matrices, splitting and re-combining heads, scaled dot-product attention
- Where it is used: Transformer encoders and decoders, multimodal models, vision models, and more
Multi-head attention is one of the key ingredients behind the success of large language models, and its design has inspired many attention variants and improvements. As models grow, implementing and optimizing multi-head attention efficiently has become an important research direction.