Multi-Head Attention Improvements

Estimated study time: 20 minutes

Multi-head attention is a key component of the Transformer architecture: it lets the model attend to different representation subspaces of the sequence in parallel. This section reviews the limitations of multi-head attention, surveys the main directions for improving it, and introduces recent techniques that raise its efficiency and expressive power.

A Recap of Multi-Head Attention

Standard multi-head attention projects the query, key, and value vectors into several subspaces, computes attention in each subspace in parallel, and then concatenates and re-projects the results:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Linear projection matrices
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        
        # Project and split into heads: [batch, num_heads, seq_len, head_dim]
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        
        # Merge heads and apply the output projection
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(output)
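
As a quick sanity check, the layer above can be exercised with random tensors; the sizes below are illustrative only:

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)   # [batch, seq_len, d_model]
out = mha(x, x, x)            # self-attention: q = k = v
print(out.shape)              # torch.Size([2, 16, 512])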

[Figure: schematic of the multi-head attention mechanism]

Limitations of Multi-Head Attention

Although multi-head attention is a key factor in the Transformer's success, it has some clear limitations:

1. Parameter Redundancy and Computational Efficiency

Standard multi-head attention carries four d_model × d_model projection matrices per layer, and the attention map itself grows quadratically with sequence length; adding more heads does not reduce either cost:

def analyze_mha_parameters(d_model, num_heads):
    """Count the parameters of a standard multi-head attention layer."""
    # Dimension of each attention head
    head_dim = d_model // num_heads
    
    # Projection matrix parameters (bias terms ignored)
    q_params = d_model * d_model  # Q projection
    k_params = d_model * d_model  # K projection
    v_params = d_model * d_model  # V projection
    o_params = d_model * d_model  # output projection
    
    total_params = q_params + k_params + v_params + o_params
    
    return {
        "head_dim": head_dim,
        "q_proj_params": q_params,
        "k_proj_params": k_params,
        "v_proj_params": v_params,
        "out_proj_params": o_params,
        "total_params": total_params,
        # Share of total model parameters, using a 175B-parameter model as the reference
        "share_of_175B_model": f"~{total_params / 175e9 * 100:.3f}% per layer",
    }
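
For example, with dimensions similar to those reported for GPT-3 175B (d_model = 12288, 96 heads), each attention layer holds roughly 604M projection parameters:

analyze_mha_parameters(12288, 96)
# head_dim = 128, total_params = 4 * 12288**2 ≈ 6.04e8,
# i.e. roughly 0.35% of a 175B-parameter model per attention layer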

Probing studies have repeatedly shown that not all attention heads are equally important: many heads contribute little or duplicate one another, which wastes compute.

2. Head Homogenization

Multi-head attention also suffers from head homogenization, where different heads learn very similar attention patterns:

def measure_head_diversity(model, dataset):
    """Measure how diverse the attention heads are (lower similarity = more diverse)."""
    # extract_attention_weights is assumed to return, for each layer, a tensor of
    # shape [num_examples, num_heads, seq_len, seq_len]
    attention_weights = extract_attention_weights(model, dataset)
    
    # Pairwise similarity between heads within each layer
    head_similarities = []
    for layer_idx, layer_weights in enumerate(attention_weights):
        layer_similarities = []
        num_heads = layer_weights.shape[1]
        
        for i in range(num_heads):
            for j in range(i + 1, num_heads):
                # Cosine similarity between the flattened attention maps
                sim = F.cosine_similarity(
                    layer_weights[:, i].reshape(-1),
                    layer_weights[:, j].reshape(-1),
                    dim=0,
                )
                layer_similarities.append(sim.item())
        
        head_similarities.append(layer_similarities)
    
    return head_similarities

3. Choosing and Scaling the Number of Heads

The number of heads is a hyperparameter, and the optimal value is hard to pin down (the quick calculation after this list makes the trade-off concrete):

  • Too few heads: limits the model's ability to capture different kinds of patterns
  • Too many heads: adds overhead and shrinks each head's dimension, which can hurt performance or lead to overfitting
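
A minimal illustration of the head-count versus head-dimension trade-off at a fixed model width (d_model = 768 is just an example value):

d_model = 768
for num_heads in (4, 8, 12, 24, 48):
    # per-head dimension shrinks as heads are added: 192, 96, 64, 32, 16
    print(num_heads, d_model // num_heads)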

Main Directions for Improving Multi-Head Attention

1. Mixture-of-Experts Attention

Treat multi-head attention as a system of experts, where each head specializes in a particular kind of pattern. The routing sketch below relies on a small single-head AttentionHead module, which is defined first:
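
This AttentionHead helper is not part of the original text; it is one plausible minimal definition, in which each expert emits a full d_model-sized output so that the mixture can simply be summed:

class AttentionHead(nn.Module):
    """A single attention head projecting d_model down to head_dim and back."""
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.q_proj = nn.Linear(d_model, head_dim)
        self.k_proj = nn.Linear(d_model, head_dim)
        self.v_proj = nn.Linear(d_model, head_dim)
        # Project back up so each expert emits a full d_model-sized output
        self.out_proj = nn.Linear(head_dim, d_model)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        return self.out_proj(torch.matmul(attn, v))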

class MoEMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_experts=None, top_k=2):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        num_experts = num_experts or num_heads
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Router that scores each expert for a given input
        self.router = nn.Linear(d_model, num_experts)
        
        # Expert heads (AttentionHead defined above)
        self.experts = nn.ModuleList([
            AttentionHead(d_model, self.head_dim) for _ in range(num_experts)
        ])
        
        self.output_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        
        # Sequence-level routing: score the experts from the mean query representation
        router_input = torch.mean(q, dim=1)  # [batch_size, d_model]
        router_logits = self.router(router_input)  # [batch_size, num_experts]
        
        # Select the top-k experts per example
        expert_weights, expert_indices = torch.topk(router_logits, self.top_k, dim=-1)
        expert_weights = F.softmax(expert_weights, dim=-1)
        
        # Weighted mixture of the selected experts' outputs
        outputs = torch.zeros(batch_size, seq_len, self.d_model, device=q.device)
        
        for i in range(self.top_k):
            # Dispatch each example to its selected expert (loop kept simple for clarity)
            for b in range(batch_size):
                expert_idx = expert_indices[b, i].item()
                weight = expert_weights[b, i].unsqueeze(0).unsqueeze(0)
                expert_out = self.experts[expert_idx](q[b:b+1], k[b:b+1], v[b:b+1], mask)
                outputs[b:b+1] += weight * expert_out
        
        return self.output_proj(outputs)

2. Dynamic Head Pruning

Decide at run time which heads matter most for the current input and suppress the rest:

class DynamicHeadPruning(nn.Module):
    def __init__(self, d_model, num_heads, prune_ratio=0.5):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.prune_ratio = prune_ratio
        
        # Standard multi-head attention components
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        
        # Head-importance predictor
        self.head_importance = nn.Linear(d_model, num_heads)
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        
        # Predict per-head importance from the mean query representation
        avg_q = q.mean(dim=1)  # [batch_size, d_model]
        head_scores = self.head_importance(avg_q)  # [batch_size, num_heads]
        
        # Keep only the most important heads
        keep_heads = int(self.num_heads * (1 - self.prune_ratio))
        _, selected_heads = torch.topk(head_scores, keep_heads, dim=1)  # [batch_size, keep_heads]
        
        # Project queries, keys, and values
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Compute attention for all heads
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [batch_size, num_heads, seq_len, head_dim]
        
        # Per-example mask that keeps only each sample's selected heads
        head_mask = torch.zeros(batch_size, self.num_heads, 1, 1, device=q.device)
        for b in range(batch_size):
            head_mask[b, selected_heads[b]] = 1.0
        
        # Zero out the pruned heads
        attn_output = attn_output * head_mask
        
        # Merge heads and project
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(output)
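
Note that this sketch still computes every head and only zeroes out the pruned ones afterwards, so by itself it saves no FLOPs; a production implementation would skip the pruned heads' projections and attention entirely, or prune them offline as in the static pruning example later in this section.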

3. Grouped Query Attention (GQA)

Keep a full set of query heads but let groups of query heads share a single key/value head, cutting KV parameters and cache size:

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, kv_groups=1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        # kv_groups = number of query heads that share one key/value head
        self.kv_groups = kv_groups
        assert num_heads % kv_groups == 0, "num_heads must be divisible by kv_groups"
        self.kv_heads = num_heads // kv_groups
        
        self.head_dim = d_model // num_heads
        
        # Query projection - full set of heads
        self.q_proj = nn.Linear(d_model, d_model)
        
        # Key/value projections - fewer heads, fewer parameters
        self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim)
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        
        # Project queries - full head count
        q = self.q_proj(q).view(batch_size, seq_len, self.num_heads, self.head_dim)
        
        # Project keys and values - reduced head count
        k = self.k_proj(k).view(batch_size, -1, self.kv_heads, self.head_dim)
        v = self.v_proj(v).view(batch_size, -1, self.kv_heads, self.head_dim)
        
        # Repeat the key/value heads so they line up with the query heads
        k = k.repeat_interleave(self.kv_groups, dim=2)
        v = v.repeat_interleave(self.kv_groups, dim=2)
        
        # Move the head dimension forward
        q = q.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, num_heads, seq_len_k, head_dim]
        v = v.transpose(1, 2)  # [batch_size, num_heads, seq_len_v, head_dim]
        
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        
        # Merge heads and project
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.out_proj(output)

Grouped-query attention is a key optimization in large language models such as LLaMA 2, and its extreme case, multi-query attention (a single shared KV head), is used in models such as PaLM; both substantially reduce KV-cache memory and compute while largely preserving quality.
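
A quick comparison of the attention parameter count for the classes defined above (the chosen dimensions are illustrative only):

def attn_param_count(module):
    return sum(p.numel() for p in module.parameters())

mha = MultiHeadAttention(d_model=4096, num_heads=32)
gqa = GroupedQueryAttention(d_model=4096, num_heads=32, kv_groups=8)  # 4 KV heads

print(attn_param_count(mha))  # ≈ 67.1M  (roughly 4 × 4096²)
print(attn_param_count(gqa))  # ≈ 37.8M  (K and V project to only 4 × 128 = 512 dims)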

4. Adaptive Multi-Head Attention

Adjust the attention configuration dynamically based on the input content:

class AdaptiveMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_options=[4, 8, 16]):
        super().__init__()
        self.d_model = d_model
        self.max_heads = max(head_options)
        self.head_options = head_options
        
        # Predicts which head count to use for each example
        self.head_predictor = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, len(head_options))
        )
        
        # Shared projection matrices, reshaped to the chosen head count at run time
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        
        # Predict the best head count per example
        # (argmax is not differentiable; training the predictor would need e.g. Gumbel-softmax)
        avg_q = q.mean(dim=1)  # [batch_size, d_model]
        head_logits = self.head_predictor(avg_q)  # [batch_size, len(head_options)]
        head_choice = torch.argmax(head_logits, dim=1)  # [batch_size]
        
        # Run attention per example with its chosen head count
        results = []
        for b in range(batch_size):
            num_heads = self.head_options[head_choice[b].item()]
            head_dim = self.d_model // num_heads
            
            # Project this single example
            q_b = self.q_proj(q[b:b+1]).view(1, seq_len, num_heads, head_dim).transpose(1, 2)
            k_b = self.k_proj(k[b:b+1]).view(1, -1, num_heads, head_dim).transpose(1, 2)
            v_b = self.v_proj(v[b:b+1]).view(1, -1, num_heads, head_dim).transpose(1, 2)
            
            # Scaled dot-product attention
            scores = torch.matmul(q_b, k_b.transpose(-2, -1)) / math.sqrt(head_dim)
            if mask is not None:
                scores = scores.masked_fill(mask[b:b+1].unsqueeze(1) == 0, -1e9)
            attn_weights = F.softmax(scores, dim=-1)
            attn_output = torch.matmul(attn_weights, v_b)
            
            # Merge heads back into d_model
            output = attn_output.transpose(1, 2).contiguous().view(1, seq_len, self.d_model)
            results.append(output)
        
        # Re-assemble the batch
        combined = torch.cat(results, dim=0)
        return self.out_proj(combined)

Practical Applications of Multi-Head Attention Optimization

1. Head-Importance Analysis

Identify and visualize how important each head is and what role it plays:

def analyze_head_importance(model, dataset):
    """Estimate the importance of each attention head from gradient magnitudes.
    
    Assumes a HuggingFace BERT-style encoder (model.encoder.layer[i].attention.self
    with query/key/value linear layers).
    """
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    head_importance = torch.zeros(num_layers, num_heads)
    
    # Accumulate gradient information over the dataset
    for batch in dataset:
        model.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        
        # Collect the gradients of each attention projection
        for layer_idx in range(num_layers):
            layer = model.encoder.layer[layer_idx].attention.self
            # Mean absolute gradient per output row of each projection
            q_grad = layer.query.weight.grad.abs().mean(dim=1)
            k_grad = layer.key.weight.grad.abs().mean(dim=1)
            v_grad = layer.value.weight.grad.abs().mean(dim=1)
            
            # A head's importance is approximated by the gradient magnitude of its rows
            for head_idx in range(num_heads):
                head_start = head_idx * (model.config.hidden_size // num_heads)
                head_end = (head_idx + 1) * (model.config.hidden_size // num_heads)
                
                head_importance[layer_idx, head_idx] += (
                    q_grad[head_start:head_end].mean() +
                    k_grad[head_start:head_end].mean() +
                    v_grad[head_start:head_end].mean()
                ).item()
    
    # Normalize by the number of batches
    head_importance = head_importance / len(dataset)
    
    return head_importance

2. Head Pruning in Practice

Remove low-importance heads to improve efficiency:

def prune_heads(model, head_importance, prune_ratio=0.3):
    """Prune attention heads based on their importance scores."""
    num_layers, num_heads = head_importance.shape
    heads_to_prune = {}
    
    # For each layer, mark the least important heads for pruning
    for layer_idx in range(num_layers):
        # Number of heads to prune in this layer
        num_to_prune = math.ceil(num_heads * prune_ratio)
        
        # Keep the most important heads; everything else gets pruned
        layer_importance = head_importance[layer_idx]
        _, indices = torch.topk(layer_importance, num_heads - num_to_prune, largest=True)
        heads_to_prune[layer_idx] = set(range(num_heads)) - set(indices.tolist())
    
    # Perform the pruning (prune_heads is provided by HuggingFace BERT-style attention modules)
    for layer_idx, heads in heads_to_prune.items():
        model.encoder.layer[layer_idx].attention.prune_heads(list(heads))
    
    return model
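
Putting the two steps together on a BERT-style model might look like this (the dataloader and the 30% ratio are illustrative):

head_importance = analyze_head_importance(model, dataloader)
model = prune_heads(model, head_importance, prune_ratio=0.3)
# The pruned model keeps ~70% of its heads per layer and is usually fine-tuned
# briefly afterwards to recover most of the lost accuracy.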

3. Grouped Query Attention in Large Models

A simplified implementation sketch in the style of LLaMA-family models:

class LLaMAAttention(nn.Module):
    """Simplified example of grouped-query attention in the style of LLaMA models.
    
    Rotary position embeddings and the KV cache are omitted for brevity.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads  # fewer KV heads than query heads
        self.head_dim = self.hidden_size // self.num_heads
        
        # Projection matrices (no bias, as in LLaMA)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        
        self.kv_heads_ratio = self.num_heads // self.num_kv_heads
        
    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        batch_size, seq_length = hidden_states.shape[:2]
        
        # Project queries, keys, and values
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        
        # Split into heads
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        value_states = value_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        
        # Repeat the KV heads so they match the number of query heads
        if self.num_kv_heads != self.num_heads:
            key_states = key_states.repeat_interleave(self.kv_heads_ratio, dim=2)
            value_states = value_states.repeat_interleave(self.kv_heads_ratio, dim=2)
        
        # Move the head dimension forward for the attention computation
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        
        # Standard attention computation
        attn_output = self._attention(query_states, key_states, value_states, attention_mask)
        
        # Merge heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
        
        # Output projection
        return self.o_proj(attn_output)
    
    def _attention(self, query, key, value, mask=None):
        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # additive mask (e.g. a causal mask filled with -inf)
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, value)

Comparison of Multi-Head Attention Improvement Methods

| Method | Parameter efficiency | Compute efficiency | Quality impact | Implementation complexity | Typical use case |
|---|---|---|---|---|---|
| Standard multi-head attention | baseline | baseline | baseline | simple | general purpose |
| Head pruning | improved | improved | slight drop | simple | model compression |
| Grouped-query attention | improved | improved | nearly unchanged | moderate | large-scale models |
| Mixture-of-experts attention | – | – | improved | complex | complex tasks |
| Adaptive multi-head attention | – | – | improved | complex | dynamic workloads |

Practical Recommendations for Multi-Head Attention Improvements

1. Choose the Right Improvement Strategy

Pick a multi-head attention variant according to model size and deployment scenario:

  • Small models (< 100M parameters): standard multi-head attention, or simple head pruning
  • Medium models (100M – 10B parameters): grouped-query attention to cut compute and KV-cache cost
  • Large models (> 10B parameters): grouped-query attention, optionally combined with expert mixtures, to balance efficiency and quality

2. Balance Head Count Against Head Dimension

The helper below enumerates valid head-count and KV-group configurations for a given model width and ranks them by attention parameter count:

def optimize_head_configuration(d_model):
    """Suggest reasonable head configurations for a given model width."""
    configurations = []
    
    # Try a range of head counts
    for num_heads in [1, 2, 4, 8, 12, 16, 24, 32, 64]:
        if d_model % num_heads != 0:
            continue
            
        head_dim = d_model // num_heads
        
        # Try grouped-query configurations (kv_groups query heads share one KV head)
        for kv_groups in [1, 2, 4, 8]:
            if num_heads % kv_groups != 0:
                continue
                
            kv_heads = num_heads // kv_groups
            
            # Rough parameter count (biases ignored)
            param_count = (
                d_model * d_model +  # Q projection
                d_model * (d_model // kv_groups) * 2 +  # K and V projections
                d_model * d_model  # output projection
            )
            
            configurations.append({
                "d_model": d_model,
                "num_heads": num_heads,
                "head_dim": head_dim,
                "kv_groups": kv_groups,
                "kv_heads": kv_heads,
                "param_count": param_count,
                "param_reduction": 1 - (param_count / (4 * d_model * d_model)),
            })
    
    # Sort by parameter count, smallest first
    return sorted(configurations, key=lambda x: x["param_count"])
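
For example, at d_model = 4096 the most parameter-efficient entry returned by this helper has eight query heads sharing each KV head:

configs = optimize_head_configuration(4096)
print(configs[0])
# {'d_model': 4096, 'num_heads': 8, 'head_dim': 512, 'kv_groups': 8,
#  'kv_heads': 1, 'param_count': 37748736, 'param_reduction': 0.4375}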

3. Visualizing and Analyzing Attention Heads

import matplotlib.pyplot as plt

def visualize_attention_heads(model, text_input):
    """Visualize the attention pattern of every head (assumes a HuggingFace model and tokenizer)."""
    # Prepare the input
    inputs = tokenizer(text_input, return_tensors="pt")
    
    # Run the model and collect attention weights
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    
    # One tensor per layer: (batch, num_heads, seq_len, seq_len)
    attentions = outputs.attentions
    
    # Plot every head of every layer
    for layer_idx, layer_attention in enumerate(attentions):
        for head_idx in range(layer_attention.shape[1]):
            attn_weights = layer_attention[0, head_idx].cpu().numpy()
            
            # Heatmap of the attention matrix
            plt.figure(figsize=(10, 8))
            plt.imshow(attn_weights, cmap="viridis")
            plt.title(f"Layer {layer_idx}, Head {head_idx}")
            plt.xlabel("Key position")
            plt.ylabel("Query position")
            plt.colorbar()
            plt.tight_layout()
            plt.show()
            
            # Classify the head's attention pattern (helper defined below)
            analyze_head_pattern(attn_weights, layer_idx, head_idx)
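
analyze_head_pattern is left undefined above; a minimal sketch, assuming we only want to distinguish a few coarse pattern types (the thresholds are arbitrary), could look like this:

import numpy as np

def analyze_head_pattern(attn_weights, layer_idx, head_idx):
    """Roughly classify a head's attention map as diagonal, focused, or diffuse."""
    seq_len = attn_weights.shape[0]
    diag_mass = np.trace(attn_weights) / seq_len  # average attention to the same position
    entropy = -(attn_weights * np.log(attn_weights + 1e-9)).sum(axis=-1).mean()
    
    if diag_mass > 0.5:
        pattern = "diagonal (attends to the current token)"
    elif entropy < 0.5 * np.log(seq_len):
        pattern = "local / focused"
    else:
        pattern = "diffuse / global"
    print(f"Layer {layer_idx}, head {head_idx}: {pattern} "
          f"(diag mass {diag_mass:.2f}, entropy {entropy:.2f})")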

Summary

Improvements to multi-head attention are central to the efficiency and quality of large language models:

  1. Parameter efficiency: techniques such as grouped-query attention and head pruning significantly reduce parameter counts and KV-cache size
  2. Computational optimization: improved multi-head mechanisms lower the cost of attention and make longer sequences practical
  3. Expressive power: mixture-of-experts and adaptive mechanisms strengthen the model's ability to capture complex patterns
  4. Engineering practice: choose the multi-head attention variant that fits the model scale and deployment scenario

As large language models continue to grow, improvements to multi-head attention will keep playing a central role in pushing both efficiency and performance forward.