Multi-Head Attention

Estimated study time: 20 minutes

Multi-head attention is one of the core innovations of the Transformer architecture. By splitting the attention mechanism into several "heads", it lets the model attend to different representation subspaces simultaneously, improving its expressive power and performance.

The Basic Idea of Multi-Head Attention

Multi-head attention projects the input into several subspaces, computes attention independently in each subspace, and then merges the results:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

where:

• W_i^Q, W_i^K, W_i^V project the queries, keys, and values into the subspace of head i, and W^O maps the concatenated head outputs back to the model dimension;
• Attention(·) is the scaled dot-product attention computed per head, as implemented below.

The dimension of each head is typically a fraction of the model dimension: with 8 heads, each head uses 1/8 of the model dimension. This keeps the total amount of computation roughly unchanged while gaining the benefits of multiple heads.
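
For example, with d_model = 512 and num_heads = 8 (illustrative values, not taken from the text above), each head operates in a 64-dimensional subspace, and the fused projection matrices stay the same size as in a single-head layer of the same width:

# Quick arithmetic check of the head dimensions (values are illustrative)
d_model, num_heads = 512, 8
d_k = d_model // num_heads                  # 512 // 8 = 64 dimensions per head
params_per_projection = d_model * d_model   # 262144, one fused d_model x d_model matrix
print(d_k, params_per_projection)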

The Multi-Head Attention Computation

1. Linear Projections

Project the query, key, and value vectors into multiple subspaces:

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head
        
        # Linear projection layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

2. Splitting into Heads

Split the projected vectors into multiple heads:

def split_heads(self, x, batch_size):
    """将张量分割成多头形式"""
    # x: [batch_size, seq_len, d_model]
    
    x = x.view(batch_size, -1, self.num_heads, self.d_k)
    # [batch_size, seq_len, num_heads, d_k]
    
    return x.transpose(1, 2)
    # [batch_size, num_heads, seq_len, d_k]

3. Computing Attention

Compute attention independently within each head:

def scaled_dot_product_attention(self, q, k, v, mask=None):
    """在每个头上计算缩放点积注意力"""
    # q, k, v: [batch_size, num_heads, seq_len, d_k]
    
    # Compute the attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
    # [batch_size, num_heads, seq_len_q, seq_len_k]
    
    # Apply the mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Compute the attention weights
    attention_weights = torch.softmax(scores, dim=-1)
    # [batch_size, num_heads, seq_len_q, seq_len_k]
    
    # Apply the attention weights to the values
    output = torch.matmul(attention_weights, v)
    # [batch_size, num_heads, seq_len_q, d_k]
    
    return output, attention_weights

4. Combining the Heads

Merge the outputs of the individual heads:

def combine_heads(self, x, batch_size):
    """将多头形式的张量合并回原始形状"""
    # x: [batch_size, num_heads, seq_len, d_k]
    
    x = x.transpose(1, 2)
    # [batch_size, seq_len, num_heads, d_k]
    
    x = x.contiguous().view(batch_size, -1, self.d_model)
    # [batch_size, seq_len, d_model]
    
    return x

5. Final Linear Projection

Apply the final linear projection layer:

def forward(self, q, k, v, mask=None):
    batch_size = q.size(0)
    
    # Linear projections
    q = self.q_linear(q)  # [batch_size, seq_len_q, d_model]
    k = self.k_linear(k)  # [batch_size, seq_len_k, d_model]
    v = self.v_linear(v)  # [batch_size, seq_len_v, d_model]
    
    # Split into multiple heads
    q = self.split_heads(q, batch_size)  # [batch_size, num_heads, seq_len_q, d_k]
    k = self.split_heads(k, batch_size)  # [batch_size, num_heads, seq_len_k, d_k]
    v = self.split_heads(v, batch_size)  # [batch_size, num_heads, seq_len_v, d_k]
    
    # Apply scaled dot-product attention
    attn_output, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)
    # attn_output: [batch_size, num_heads, seq_len_q, d_k]
    
    # Merge the heads
    output = self.combine_heads(attn_output, batch_size)  # [batch_size, seq_len_q, d_model]
    
    # Final linear projection
    output = self.output_linear(output)  # [batch_size, seq_len_q, d_model]
    
    return output, attention_weights

A Complete Multi-Head Attention Implementation

Putting the steps above together into a single PyTorch module:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projection layers
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)
    
    def combine_heads(self, x, batch_size):
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, -1, self.d_model)
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # Linear projections
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        
        # Split into multiple heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        
        # Compute attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        attn_output = torch.matmul(attention_weights, v)
        
        # Merge the heads
        output = self.combine_heads(attn_output, batch_size)
        
        # Final linear projection
        output = self.output_linear(output)
        
        return output, attention_weights
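
A quick smoke test of the module above (the values are random; we only check shapes):

# Example usage: self-attention over a batch of 2 sequences of length 10
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(out.shape)      # torch.Size([2, 10, 512])
print(weights.shape)  # torch.Size([2, 8, 10, 10])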

Figure: schematic of the multi-head attention mechanism

Advantages of Multi-Head Attention

1. Enhanced Representational Power

Multi-head attention lets the model attend to different representation subspaces at the same time:

import numpy as np

def analyze_head_specialization(model, sentence):
    """Analyze how specialized the individual attention heads are.

    Assumes a Hugging Face-style `model` and a `tokenizer` available in the
    enclosing scope.
    """
    tokens = tokenizer(sentence, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)
    
    # Extract the attention weights of one layer
    layer_attention = outputs.attentions[6]  # layer 7 (index 6)
    
    # Analyze which positions each head attends to
    specializations = []
    
    for head_idx in range(layer_attention.shape[1]):
        head_attention = layer_attention[0, head_idx].numpy()
        
        # Attention entropy (how spread out the attention is)
        entropy = -np.sum(
            head_attention * np.log(head_attention + 1e-9), 
            axis=1
        ).mean()
        
        # Find the strongest attention connection for each query position
        strongest_connections = []
        for i in range(head_attention.shape[0]):
            max_idx = np.argmax(head_attention[i])
            strongest_connections.append((i, max_idx))
        
        specializations.append({
            "head_idx": head_idx,
            "entropy": entropy,
            "strongest_connections": strongest_connections
        })
    
    return specializations
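
One possible way to call this helper, assuming a Hugging Face encoder (the model name below is illustrative) and that the global `tokenizer` expected by the function has been defined:

# Illustrative usage; requires the transformers library
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

specs = analyze_head_specialization(model, "The cat sat on the mat.")
for s in specs:
    print(s["head_idx"], round(s["entropy"], 3))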

2. Improved Stability

The multi-head structure reduces the influence of randomness in any single attention mechanism:

def test_attention_stability(model, sentence, num_runs=10, noise_level=0.01):
    """Test how stable multi-head attention outputs are under input noise.

    Assumes a Hugging Face-style `model` and a `tokenizer` available in the
    enclosing scope.
    """
    tokens = tokenizer(sentence, return_tensors="pt")
    input_embeds = model.get_input_embeddings()(tokens.input_ids)
    
    # Output on the original input
    with torch.no_grad():
        original_output = model(inputs_embeds=input_embeds).last_hidden_state
    
    # Outputs on noise-perturbed inputs
    noisy_outputs = []
    for _ in range(num_runs):
        noise = torch.randn_like(input_embeds) * noise_level
        noisy_embeds = input_embeds + noise
        
        with torch.no_grad():
            noisy_output = model(inputs_embeds=noisy_embeds).last_hidden_state
        
        noisy_outputs.append(noisy_output)
    
    # Stability metric: mean squared deviation from the original output
    variances = torch.stack([
        torch.mean((output - original_output)**2)
        for output in noisy_outputs
    ])
    
    return {
        "mean_variance": variances.mean().item(),
        "max_variance": variances.max().item()
    }

3. Efficient Parallel Computation

The heads can be computed in parallel, which improves hardware utilization:

# An efficient multi-head attention core that batches all heads together
def efficient_multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    batch_size, seq_len, _ = q.shape
    d_k = d_model // num_heads
    
    # Reshape into heads (q, k, v are assumed to be already linearly projected)
    q = q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    v = v.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    
    # Compute attention for all heads in one batched operation
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    
    weights = torch.softmax(scores, dim=-1)
    attn_output = torch.matmul(weights, v)
    
    # Merge the heads
    attn_output = attn_output.transpose(1, 2).contiguous().view(
        batch_size, seq_len, d_model
    )
    
    return attn_output
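
On PyTorch 2.0 or newer, the per-head attention itself can also be delegated to the fused kernel torch.nn.functional.scaled_dot_product_attention, which dispatches to optimized (FlashAttention-style) implementations when available. A minimal sketch with the same interface as the function above:

import torch.nn.functional as F

def fused_multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    # q, k, v: [batch_size, seq_len, d_model], already linearly projected
    batch_size, seq_len, _ = q.shape
    d_k = d_model // num_heads

    q = q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    v = v.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

    # Convert the 0/1 mask convention used above into a boolean mask (True = attend)
    attn_mask = mask.unsqueeze(1).bool() if mask is not None else None
    attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    # Merge the heads back together
    return attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)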

Applications of Multi-Head Attention

1. In the Transformer Encoder

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Multi-head self-attention + residual connection
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward network + residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
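
A minimal sketch of how such layers are stacked into a small encoder (the hyperparameters in the comment are illustrative):

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        # Pass the (embedded) input through each encoder layer in turn
        for layer in self.layers:
            x = layer(x, mask)
        return x

# e.g. encoder = TransformerEncoder(num_layers=6, d_model=512, num_heads=8, d_ff=2048)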

2. In the Transformer Decoder

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Masked multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Encoder-decoder (cross) multi-head attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        # Masked multi-head self-attention + residual connection
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Encoder-decoder multi-head attention + residual connection
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed-forward network + residual connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x
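
The masked self-attention needs a causal (lower-triangular) mask so that position i can only attend to positions up to i, while cross-attention typically uses a padding mask over the source. A sketch of building both in the 0/1 convention used by the attention code above (pad_id=0 is an assumption):

def make_causal_mask(seq_len, device=None):
    # [1, seq_len, seq_len]: 1 where attention is allowed, 0 for future positions
    return torch.tril(torch.ones(1, seq_len, seq_len, device=device, dtype=torch.long))

def make_padding_mask(token_ids, pad_id=0):
    # [batch_size, 1, seq_len]: 0 at padding positions so they receive no attention
    return (token_ids != pad_id).long().unsqueeze(1)

# tgt_mask = make_causal_mask(tgt_len)         -> tgt_mask for masked self-attention
# src_mask = make_padding_mask(src_token_ids)  -> src_mask for cross-attention
# For padded targets, combine: tgt_mask = make_causal_mask(L) * make_padding_mask(tgt_ids)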

Variants of Multi-Head Attention

1. Separable Multi-Head Attention

Here every head has its own explicit projection matrices. The total parameter count is essentially the same as in the fused implementation above (a single d_model × d_model projection is just the per-head projections stacked together), but the per-head structure is made explicit, at the cost of a slower loop-based computation:

class SeparableMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        
        # Each head has its own independent projection matrices
        self.q_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.k_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.v_linears = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        
        self.output_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # Collect the output of each head
        head_outputs = []
        
        for i in range(self.num_heads):
            # Per-head projections
            q_i = self.q_linears[i](q)  # [batch_size, seq_len, d_k]
            k_i = self.k_linears[i](k)
            v_i = self.v_linears[i](v)
            
            # Compute attention for this head
            scores = torch.matmul(q_i, k_i.transpose(-2, -1)) / math.sqrt(self.d_k)
            
            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)
            
            weights = torch.softmax(scores, dim=-1)
            weights = self.dropout(weights)
            
            head_output = torch.matmul(weights, v_i)
            head_outputs.append(head_output)
        
        # Concatenate the head outputs
        output = torch.cat(head_outputs, dim=-1)
        
        # Final linear projection
        output = self.output_linear(output)
        
        return output
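
A quick check of the parameter-count claim above: the per-head projections contain exactly as many parameters as the fused d_model × d_model projections (the values here are illustrative):

fused = MultiHeadAttention(d_model=512, num_heads=8)
separable = SeparableMultiHeadAttention(d_model=512, num_heads=8)
print(sum(p.numel() for p in fused.parameters()))      # 1050624
print(sum(p.numel() for p in separable.parameters()))  # 1050624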

2. Multi-Head Attention with Relative Position Encoding

Adding relative position information strengthens the model's ability to capture sequence structure:

class RelativeMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, max_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.max_len = max_len
        
        # Linear projections
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        
        # Relative position embeddings: one learnable vector per possible offset
        self.rel_pos_encoding = nn.Parameter(
            torch.randn(2 * max_len - 1, self.d_k)
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        
        # Linear projections
        q = self.q_linear(q).view(batch_size, seq_len, self.num_heads, self.d_k)
        k = self.k_linear(k).view(batch_size, seq_len, self.num_heads, self.d_k)
        v = self.v_linear(v).view(batch_size, seq_len, self.num_heads, self.d_k)
        
        # Move the head dimension forward
        q = q.transpose(1, 2)  # [batch_size, num_heads, seq_len, d_k]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Content-content attention scores
        content_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Content-position attention scores (simplified implementation):
        # look up an embedding for each relative offset (i - j) and dot it with the queries
        positions = torch.arange(seq_len, device=q.device)
        rel_pos_indices = positions.unsqueeze(1) - positions.unsqueeze(0) + self.max_len - 1
        rel_pos = self.rel_pos_encoding[rel_pos_indices]  # [seq_len, seq_len, d_k]
        
        position_scores = torch.einsum("bhid,ijd->bhij", q, rel_pos) / math.sqrt(self.d_k)
        
        # Combine the two score terms
        scores = content_scores + position_scores
        
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        
        output = torch.matmul(weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.output_linear(output)
        
        return output

Analyzing Multi-Head Attention

Visualizing Attention Heads

Analyzing the patterns that different heads attend to:

import matplotlib.pyplot as plt

def visualize_attention_heads(model, text, layer_idx=6):
    """Visualize the attention patterns of every head in a given layer.

    Assumes a Hugging Face-style `model` and a `tokenizer` available in the
    enclosing scope.
    """
    # Tokenize the input
    tokens = tokenizer(text, return_tensors="pt")
    token_strings = tokenizer.convert_ids_to_tokens(tokens.input_ids[0])
    
    # Get the attention weights
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)
    
    # Extract the attention weights of the chosen layer
    attention = outputs.attentions[layer_idx][0]  # [num_heads, seq_len, seq_len]
    
    # Plot one attention heatmap per head
    fig, axes = plt.subplots(
        int(math.ceil(attention.shape[0] / 4)), 4, 
        figsize=(20, 5 * math.ceil(attention.shape[0] / 4))
    )
    axes = axes.flatten()
    
    for head_idx in range(attention.shape[0]):
        if head_idx < len(axes):
            ax = axes[head_idx]
            im = ax.imshow(attention[head_idx], cmap="viridis")
            ax.set_title(f"Head {head_idx}")
            ax.set_xticks(range(len(token_strings)))
            ax.set_yticks(range(len(token_strings)))
            ax.set_xticklabels(token_strings, rotation=90)
            ax.set_yticklabels(token_strings)
            plt.colorbar(im, ax=ax)
    
    plt.tight_layout()
    plt.show()

What Different Heads Attend To

The different heads in multi-head attention tend to focus on different types of information (a rough diagnostic sketch follows the table):

| Head type        | Tends to capture          | Common observations                                   |
|------------------|---------------------------|-------------------------------------------------------|
| Syntactic heads  | Syntactic dependencies    | Attend to subject-verb, preposition-noun pairs, etc.  |
| Semantic heads   | Semantic relatedness      | Attend to related concepts, synonyms, and context     |
| Positional heads | Positional relationships  | Attend to adjacent tokens or tokens at a fixed offset |
| Global heads     | Global information        | Attend almost uniformly to all tokens                 |
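
A rough heuristic for telling positional and global heads apart from their attention weights (a sketch only; the thresholds below are arbitrary assumptions, not established values):

import numpy as np

def categorize_head(head_attention):
    """head_attention: [seq_len, seq_len] NumPy array of attention weights."""
    seq_len = head_attention.shape[0]
    positions = np.arange(seq_len)

    # Expected distance between each query and the positions it attends to
    mean_distance = float(
        (head_attention * np.abs(positions[None, :] - positions[:, None])).sum(axis=1).mean()
    )
    # Entropy of the attention distribution: high entropy means diffuse, "global" attention
    entropy = float(-(head_attention * np.log(head_attention + 1e-9)).sum(axis=1).mean())

    if mean_distance < 1.5:
        return "positional (local)"
    if entropy > 0.9 * np.log(seq_len):
        return "global (diffuse)"
    return "content-based (syntactic or semantic)"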

Summary

Multi-head attention improves the model's expressive power and stability by computing several independent attention mechanisms in parallel:

  1. Core idea: split one high-dimensional attention computation into several parallel, lower-dimensional ones
  2. Main advantages
    • Greater expressive power, capturing different kinds of dependencies
    • Better stability and robustness
    • Efficient parallel computation
  3. Implementation essentials: projection matrices, splitting and merging heads, scaled dot-product attention
  4. Typical applications: Transformer encoders and decoders, multimodal models, vision models, and more

Multi-head attention is one of the key factors behind the success of large language models, and its design has inspired many variants and refinements of the attention mechanism. As model sizes grow, implementing and optimizing multi-head attention efficiently has become an important research direction.