Multi-Head Attention Improvements
Estimated study time: 20 minutes
Multi-head attention is a key component of the Transformer architecture, allowing the model to attend to different representation subspaces of a sequence simultaneously. This section examines the limitations of multi-head attention and the main directions for improving it, introducing recent techniques that increase its efficiency and expressive power.
Review of the Multi-Head Attention Mechanism
Standard multi-head attention projects the query, key, and value vectors into multiple subspaces, computes attention in parallel for each head, and then merges the results:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Linear projection matrices
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        # Project and split into heads
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Reshape and merge the head outputs
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(output)
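As a quick sanity check, here is a minimal usage sketch; the dimensions and batch shapes below are illustrative choices, not values from the original text:

# Illustrative usage (shapes chosen arbitrarily for demonstration)
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # [batch, seq_len, d_model]
out = mha(x, x, x)            # self-attention: q, k, v all come from x
print(out.shape)              # torch.Size([2, 10, 512])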
Limitations of Multi-Head Attention
Although multi-head attention is a key factor behind the success of Transformers, it has several notable limitations:
1. Parameter Redundancy and Computational Efficiency
Standard multi-head attention always carries four full d_model × d_model projection matrices (Q, K, V, and the output projection), and this full budget of parameters and compute is paid regardless of how many heads actually contribute useful information:
def analyze_mha_parameters(d_model, num_heads):
    """Analyze the parameter count of multi-head attention (per layer)."""
    # Dimension of each attention head
    head_dim = d_model // num_heads
    # Projection matrix parameters
    q_params = d_model * d_model  # Q projection
    k_params = d_model * d_model  # K projection
    v_params = d_model * d_model  # V projection
    o_params = d_model * d_model  # output projection
    total_params = q_params + k_params + v_params + o_params
    return {
        "head_dim": head_dim,
        "q_proj_params": q_params,
        "k_proj_params": k_params,
        "v_proj_params": v_params,
        "out_proj_params": o_params,
        "total_params": total_params,
        # Share of total model parameters, using a 175B-parameter model as the reference
        "share_of_model": f"about {total_params / 175e9 * 100:.2f}% per attention layer",
    }
Research shows that not all attention heads are equally important: many heads contribute little or duplicate one another, wasting computational resources.
2. Head Homogenization
Multi-head attention also suffers from head homogenization, where different heads learn very similar attention patterns:
def measure_head_diversity(model, dataset):
    """Measure the diversity between attention heads."""
    # Extract the attention weights for every head in every layer
    # (extract_attention_weights is assumed to return, per layer, a tensor of
    #  shape [num_examples, num_heads, seq_len, seq_len])
    attention_weights = extract_attention_weights(model, dataset)
    # Compute pairwise similarity between heads
    head_similarities = []
    for layer_idx, layer_weights in enumerate(attention_weights):
        layer_similarities = []
        num_heads = layer_weights.shape[1]
        for i in range(num_heads):
            for j in range(i + 1, num_heads):
                # Cosine similarity between the flattened attention maps
                sim = F.cosine_similarity(
                    layer_weights[:, i].reshape(1, -1),
                    layer_weights[:, j].reshape(1, -1),
                ).item()
                layer_similarities.append(sim)
        head_similarities.append(layer_similarities)
    return head_similarities
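The helper `extract_attention_weights` is not defined in the original text. A minimal sketch, assuming a Hugging Face-style model that accepts `output_attentions=True` and fixed-length (padded) batches, might look like this:

def extract_attention_weights(model, dataset):
    """Collect per-layer attention maps (hypothetical helper, not part of the original text)."""
    per_layer = None
    model.eval()
    with torch.no_grad():
        for batch in dataset:
            outputs = model(**batch, output_attentions=True)
            # outputs.attentions: tuple of [batch, num_heads, seq_len, seq_len], one entry per layer
            if per_layer is None:
                per_layer = [[] for _ in outputs.attentions]
            for layer_idx, attn in enumerate(outputs.attentions):
                per_layer[layer_idx].append(attn.cpu())
    # Requires a consistent seq_len across batches so the maps can be stacked
    return [torch.cat(chunks, dim=0) for chunks in per_layer]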
3. Head Count Selection and Scaling
The number of heads is a hyperparameter, and the optimal value is hard to determine (see the short example after this list):
- Too few heads limits the model's ability to capture distinct patterns
- Too many heads increases computational overhead and can lead to overfitting
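A quick, illustrative calculation (the d_model value is an assumption, not from the original text) makes the trade-off concrete: with d_model fixed, more heads means each head works in a smaller subspace.

d_model = 768  # illustrative; BERT-base, for example, uses this dimension
for num_heads in (4, 12, 48, 96):
    if d_model % num_heads == 0:
        print(f"{num_heads:3d} heads -> head_dim = {d_model // num_heads}")
# 4 heads -> head_dim = 192; 96 heads -> head_dim = 8, likely too small to be expressive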
Main Directions for Improving Multi-Head Attention
1. Mixture-of-Experts Attention
Treat the attention heads as a pool of experts, with each head responsible for specific patterns:
class MoEMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_experts=None, top_k=2):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        num_experts = num_experts or num_heads
        self.num_experts = num_experts
        self.top_k = top_k
        # Router that scores each expert
        self.router = nn.Linear(d_model, num_experts)
        # Expert heads
        self.experts = nn.ModuleList([
            AttentionHead(d_model, self.head_dim) for _ in range(num_experts)
        ])
        self.output_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        # Routing scores for each expert, based on the mean-pooled query
        router_input = torch.mean(q, dim=1)        # [batch_size, d_model]
        router_logits = self.router(router_input)  # [batch_size, num_experts]
        # Select the top-k experts per sample
        expert_weights, expert_indices = torch.topk(router_logits, self.top_k, dim=-1)
        expert_weights = F.softmax(expert_weights, dim=-1)
        # Mix the expert outputs
        outputs = torch.zeros(batch_size, seq_len, self.d_model, device=q.device)
        for i in range(self.top_k):
            # Route each sample to its chosen expert
            for b in range(batch_size):
                expert_idx = expert_indices[b, i].item()
                weight = expert_weights[b, i].unsqueeze(0).unsqueeze(0)
                sample_mask = mask[b:b+1] if mask is not None else None
                expert_out = self.experts[expert_idx](q[b:b+1], k[b:b+1], v[b:b+1], sample_mask)
                outputs[b:b+1] += weight * expert_out
        return self.output_proj(outputs)
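The `AttentionHead` module used above is not defined in the original text. A minimal sketch consistent with the shapes the mixture code expects (input and output of size d_model, attention computed in a head_dim-sized subspace) could be:

class AttentionHead(nn.Module):
    """Hypothetical single-head attention expert (an assumption, not from the original text)."""
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.q_proj = nn.Linear(d_model, head_dim)
        self.k_proj = nn.Linear(d_model, head_dim)
        self.v_proj = nn.Linear(d_model, head_dim)
        self.out_proj = nn.Linear(head_dim, d_model)  # map back to d_model so expert outputs can be summed

    def forward(self, q, k, v, mask=None):
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        return self.out_proj(torch.matmul(attn, v))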
2. Dynamic Head Pruning
Dynamically determine the most important heads for each input sequence and ignore the rest:
class DynamicHeadPruning(nn.Module):
    def __init__(self, d_model, num_heads, prune_ratio=0.5):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.prune_ratio = prune_ratio
        # Standard multi-head attention components
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # Head-importance predictor
        self.head_importance = nn.Linear(d_model, num_heads)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        # Predict head importance
        avg_q = q.mean(dim=1)                      # [batch_size, d_model]
        head_scores = self.head_importance(avg_q)  # [batch_size, num_heads]
        # Keep only the most important heads
        keep_heads = int(self.num_heads * (1 - self.prune_ratio))
        _, selected_heads = torch.topk(head_scores, keep_heads, dim=1)  # [batch_size, keep_heads]
        # Project queries, keys, values
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Compute attention for all heads
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [batch_size, num_heads, seq_len, head_dim]
        # Build a per-sample mask that keeps only the selected heads
        head_mask = torch.zeros(batch_size, self.num_heads, 1, 1, device=q.device)
        for b in range(batch_size):
            head_mask[b, selected_heads[b]] = 1.0
        # Apply the head mask
        attn_output = attn_output * head_mask
        # Merge head outputs
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(output)
3. Grouped Query Attention (GQA)
Group the query heads so that each group shares a single key/value projection, reducing parameters and computational overhead:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, kv_groups=1):
        super().__init__()
        assert num_heads % kv_groups == 0, "num_heads must be divisible by kv_groups"
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_groups = kv_groups               # query heads per key/value head
        self.kv_heads = num_heads // kv_groups   # number of key/value heads
        self.head_dim = d_model // num_heads
        # Query projection - full set of heads
        self.q_proj = nn.Linear(d_model, d_model)
        # Key/value projections - fewer heads, fewer parameters
        self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        # Project queries - full head count
        q = self.q_proj(q).view(batch_size, seq_len, self.num_heads, self.head_dim)
        # Project keys/values - reduced head count
        k = self.k_proj(k).view(batch_size, -1, self.kv_heads, self.head_dim)
        v = self.v_proj(v).view(batch_size, -1, self.kv_heads, self.head_dim)
        # Repeat key/value heads to match the number of query heads
        k = k.repeat_interleave(self.kv_groups, dim=2)
        v = v.repeat_interleave(self.kv_groups, dim=2)
        # Reorder dimensions for attention
        q = q.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, num_heads, seq_len_k, head_dim]
        v = v.transpose(1, 2)  # [batch_size, num_heads, seq_len_v, head_dim]
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Merge head outputs
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(output)
Grouped query attention (with multi-query attention as its extreme case of a single key/value head) is a key optimization in large language models such as LLaMA 2, and PaLM uses the closely related multi-query attention. It significantly reduces compute and memory requirements, in particular the size of the inference-time KV cache, while largely preserving quality.
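A short, illustrative comparison makes the saving concrete; all dimensions below are assumptions chosen for demonstration. With 32 query heads sharing 8 key/value heads, the key/value projections (and hence the per-token KV cache) shrink by 4x relative to standard multi-head attention.

# Illustrative only: compare the K/V projection parameters of the two modules defined above
d_model, num_heads, kv_groups = 4096, 32, 4   # 32 query heads, 8 key/value heads
mha = MultiHeadAttention(d_model, num_heads)
gqa = GroupedQueryAttention(d_model, num_heads, kv_groups)

def kv_params(module):
    return sum(p.numel() for name, p in module.named_parameters()
               if name.startswith(("k_proj", "v_proj")))

print(kv_params(mha))  # about 2 * d_model * d_model (plus biases)
print(kv_params(gqa))  # about 2 * d_model * d_model / 4 (plus biases), a 4x reduction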
4. Adaptive Multi-Head Attention
Dynamically adjust the attention configuration based on the input content:
class AdaptiveMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_options=(4, 8, 16)):
        super().__init__()
        self.d_model = d_model
        self.max_heads = max(head_options)
        self.head_options = head_options
        # Predictor for the number of heads to use
        self.head_predictor = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, len(head_options))
        )
        # Projection matrices sized for the full model dimension
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        # Predict the best head count (argmax is non-differentiable; shown for illustration)
        avg_q = q.mean(dim=1)                           # [batch_size, d_model]
        head_logits = self.head_predictor(avg_q)        # [batch_size, len(head_options)]
        head_choice = torch.argmax(head_logits, dim=1)  # [batch_size]
        # Run attention per sample with its chosen head count
        results = []
        for b in range(batch_size):
            num_heads = self.head_options[head_choice[b].item()]
            head_dim = self.d_model // num_heads
            # Project a single sample
            q_b = self.q_proj(q[b:b+1]).view(1, seq_len, num_heads, head_dim).transpose(1, 2)
            k_b = self.k_proj(k[b:b+1]).view(1, -1, num_heads, head_dim).transpose(1, 2)
            v_b = self.v_proj(v[b:b+1]).view(1, -1, num_heads, head_dim).transpose(1, 2)
            # Compute attention
            scores = torch.matmul(q_b, k_b.transpose(-2, -1)) / math.sqrt(head_dim)
            if mask is not None:
                scores = scores.masked_fill(mask[b:b+1].unsqueeze(1) == 0, -1e9)
            attn_weights = F.softmax(scores, dim=-1)
            attn_output = torch.matmul(attn_weights, v_b)
            # Reshape the output
            output = attn_output.transpose(1, 2).contiguous().view(1, seq_len, self.d_model)
            results.append(output)
        # Concatenate the per-sample results
        combined = torch.cat(results, dim=0)
        return self.out_proj(combined)
Practical Applications of Multi-Head Attention Optimization
1. Head Importance Analysis
Identify and visualize the importance and function of individual heads:
def analyze_head_importance(model, dataset):
    """Estimate the importance of each attention head from gradient magnitudes."""
    # Initialize head-importance scores
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    head_importance = torch.zeros(num_layers, num_heads)
    # Accumulate gradient information
    for batch in dataset:
        model.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        # Per-output-dimension gradient magnitudes of the Q/K/V projections
        for layer_idx in range(num_layers):
            layer = model.encoder.layer[layer_idx].attention.self  # BERT-style layout assumed
            q_grad = layer.query.weight.grad.abs().mean(dim=1)  # [hidden_size]
            k_grad = layer.key.weight.grad.abs().mean(dim=1)
            v_grad = layer.value.weight.grad.abs().mean(dim=1)
            # Use the gradient magnitude of each head's slice as its importance
            for head_idx in range(num_heads):
                head_start = head_idx * head_dim
                head_end = (head_idx + 1) * head_dim
                head_importance[layer_idx, head_idx] += (
                    q_grad[head_start:head_end].mean() +
                    k_grad[head_start:head_end].mean() +
                    v_grad[head_start:head_end].mean()
                ).item()
    # Normalize by the number of batches
    head_importance = head_importance / len(dataset)
    return head_importance
2. Head Pruning in Practice
Remove low-importance heads to improve efficiency:
def prune_heads(model, head_importance, prune_ratio=0.3):
    """Prune attention heads based on their importance scores."""
    num_layers, num_heads = head_importance.shape
    heads_to_prune = {}
    # For each layer, mark the least important heads for pruning
    for layer_idx in range(num_layers):
        # Number of heads to prune in this layer
        num_to_prune = math.ceil(num_heads * prune_ratio)
        # Keep the most important heads; everything else is pruned
        layer_importance = head_importance[layer_idx]
        _, indices = torch.topk(layer_importance, num_heads - num_to_prune, largest=True)
        heads_to_prune[layer_idx] = set(range(num_heads)) - set(indices.tolist())
    # Perform the pruning (assumes a Hugging Face BERT-style prune_heads API)
    for layer_idx, heads in heads_to_prune.items():
        model.encoder.layer[layer_idx].attention.prune_heads(list(heads))
    return model
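A brief usage sketch tying the two functions together, assuming a `model` and a small `calibration_dataset` are already available (both names are illustrative):

# Illustrative end-to-end flow: score the heads, then prune the weakest 30%
importance = analyze_head_importance(model, calibration_dataset)
model = prune_heads(model, importance, prune_ratio=0.3)
# After pruning, re-evaluate on a held-out set to confirm the quality drop is acceptable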
3. Grouped Query Attention in Large Models
An illustrative implementation in the style of LLaMA-family models (architectural details of closed models such as GPT-4 are not public):
class LLaMAAttention(nn.Module):
    """Example of grouped query attention in the style of LLaMA-family models."""
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads  # fewer key/value heads than query heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_heads_ratio = self.num_heads // self.num_kv_heads
        # Projection matrices
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        # Note: rotary position embeddings (applied via position_ids in the real model) are omitted here
        batch_size, seq_length = hidden_states.shape[:2]
        # Project queries, keys, values
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # Reshape into heads
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        value_states = value_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        # Repeat key/value heads to match the query heads
        if self.num_kv_heads != self.num_heads:
            key_states = key_states.repeat_interleave(self.kv_heads_ratio, dim=2)
            value_states = value_states.repeat_interleave(self.kv_heads_ratio, dim=2)
        # Transpose for attention computation
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Standard attention computation
        attn_output = self._attention(query_states, key_states, value_states, attention_mask)
        # Reshape the output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
        # Output projection
        return self.o_proj(attn_output)

    def _attention(self, query, key, value, mask=None):
        # Standard scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, value)
Comparison of Multi-Head Attention Improvement Methods

Method | Parameter Efficiency | Compute Efficiency | Quality Impact | Implementation Complexity | Typical Use Case
---|---|---|---|---|---
Standard multi-head attention | Low | Low | Baseline | Simple | General purpose
Head pruning | Medium | Medium | Slight degradation | Simple | Model compression
Grouped query attention | High | High | Nearly unchanged | Moderate | Large-scale models
Mixture-of-experts attention | Low | Medium | Improvement | Complex | Complex tasks
Adaptive multi-head attention | Medium | Medium | Improvement | Complex | Dynamic workloads
Practical Recommendations for Improving Multi-Head Attention
1. Choose an Appropriate Improvement Strategy
Choose the multi-head attention variant according to model scale and application scenario (a small helper sketch follows this list):
- Small models (< 100M parameters): standard multi-head attention or simple head pruning
- Medium models (100M - 10B parameters): grouped query attention to reduce compute
- Large models (> 10B parameters): grouped query attention plus mixture-of-experts, balancing efficiency and quality
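As a rough illustration of these guidelines (the thresholds simply mirror the list above; the helper itself is not from the original text):

def recommend_attention_variant(num_params):
    """Map a model size (total parameter count) to a suggested attention variant (illustrative)."""
    if num_params < 100e6:
        return "standard MHA or light head pruning"
    elif num_params < 10e9:
        return "grouped query attention"
    else:
        return "grouped query attention + mixture-of-experts attention"

print(recommend_attention_variant(7e9))  # grouped query attention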
2. Balance Head Count Against Head Dimension
def optimize_head_configuration(d_model):
    """Suggest reasonable head configurations for a given model dimension."""
    configurations = []
    # Try different head counts
    for num_heads in [1, 2, 4, 8, 12, 16, 24, 32, 64]:
        if d_model % num_heads != 0:
            continue
        head_dim = d_model // num_heads
        # Try grouped-query configurations
        for kv_groups in [1, 2, 4, 8]:
            if num_heads % kv_groups != 0:
                continue
            kv_heads = num_heads // kv_groups
            # Rough parameter and compute estimate
            param_count = (
                d_model * d_model +                     # Q projection
                d_model * (d_model // kv_groups) * 2 +  # K and V projections
                d_model * d_model                       # output projection
            )
            # Add to the list of candidates
            configurations.append({
                "d_model": d_model,
                "num_heads": num_heads,
                "head_dim": head_dim,
                "kv_groups": kv_groups,
                "kv_heads": kv_heads,
                "param_count": param_count,
                "param_reduction": 1 - (param_count / (4 * d_model * d_model)),
            })
    # Sort by parameter count
    return sorted(configurations, key=lambda x: x["param_count"])
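For example (d_model = 4096 is just an illustrative choice), the most parameter-efficient candidates can be listed like this:

# Print the three most parameter-efficient configurations for an illustrative d_model
for cfg in optimize_head_configuration(4096)[:3]:
    print(cfg["num_heads"], "heads,", cfg["kv_heads"], "KV heads,",
          f"{cfg['param_reduction']:.0%} fewer attention parameters")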
3. Visualizing and Analyzing Attention Heads
import matplotlib.pyplot as plt

def visualize_attention_heads(model, text_input):
    """Visualize multi-head attention patterns (assumes a Hugging Face tokenizer is in scope)."""
    # Prepare the input
    inputs = tokenizer(text_input, return_tensors="pt")
    # Get the attention weights
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Attention weights for every layer
    attentions = outputs.attentions  # tuple of (batch, num_heads, seq_len, seq_len)
    # Visualize every head in every layer
    for layer_idx, layer_attention in enumerate(attentions):
        for head_idx in range(layer_attention.shape[1]):
            attn_weights = layer_attention[0, head_idx].numpy()
            # Plot a heatmap
            plt.figure(figsize=(10, 8))
            plt.imshow(attn_weights, cmap="viridis")
            plt.title(f"Layer {layer_idx}, Head {head_idx}")
            plt.xlabel("Key position")
            plt.ylabel("Query position")
            plt.colorbar()
            plt.tight_layout()
            plt.show()
            # Analyze what this head focuses on
            analyze_head_pattern(attn_weights, layer_idx, head_idx)
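`analyze_head_pattern` is not defined in the original text; one possible minimal sketch classifies a head by how much of its attention mass falls near the diagonal (local/positional behavior) versus being spread out:

import numpy as np

def analyze_head_pattern(attn_weights, layer_idx, head_idx, local_window=1):
    """Hypothetical helper: label a head as 'local/positional' or 'dispersed' (not from the original text)."""
    seq_len = attn_weights.shape[0]
    # Fraction of attention mass within +/- local_window of the diagonal
    local_mass = sum(
        attn_weights[i, max(0, i - local_window):min(seq_len, i + local_window + 1)].sum()
        for i in range(seq_len)
    ) / seq_len
    # Entropy of the attention distribution, averaged over query positions
    entropy = float(-(attn_weights * np.log(attn_weights + 1e-9)).sum(axis=-1).mean())
    kind = "local/positional" if local_mass > 0.5 else "dispersed"
    print(f"Layer {layer_idx}, Head {head_idx}: {kind} "
          f"(local mass {local_mass:.2f}, mean entropy {entropy:.2f})")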
Summary
Improvements to multi-head attention are critical to the efficiency and performance of large language models:
- Parameter efficiency: techniques such as grouped query attention and head pruning significantly reduce parameter counts
- Compute optimization: improved multi-head mechanisms lower computational cost and support longer sequences
- Expressive power: mixture-of-experts and adaptive mechanisms strengthen the model's ability to capture complex patterns
- Engineering practice: choose the multi-head attention variant that matches the model scale and application scenario
As large language models continue to grow, improvements to multi-head attention will remain central to advances in both efficiency and capability.