Self Attention

预计学习时间:25分钟

自注意力(Self Attention)是Transformer架构的核心组件,能够捕获序列内部的依赖关系。与传统的RNN不同,自注意力通过直接计算序列中所有位置之间的关联,实现了并行计算和捕获长距离依赖的能力。

自注意力的基本原理

自注意力机制的核心思想是让序列中的每个元素都能"关注"序列中的所有其他元素,并根据关联程度分配不同的注意力权重:

其中:

  • (查询):当前位置的表示,用于与其他位置进行匹配
  • (键):所有位置的表示,用于被查询匹配
  • (值):所有位置的表示,用于信息聚合
  • :键向量的维度,用于缩放点积以避免梯度消失

在自注意力中,都是由同一个输入序列通过不同的线性变换得到的。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        
        # 线性变换矩阵
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        self.out = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.size()
        
        # 线性变换
        q = self.q_linear(x)  # [batch_size, seq_len, d_model]
        k = self.k_linear(x)  # [batch_size, seq_len, d_model]
        v = self.v_linear(x)  # [batch_size, seq_len, d_model]
        
        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_model)
        # [batch_size, seq_len, seq_len]
        
        # 应用掩码(如果提供)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 应用softmax得到注意力权重
        attn_weights = F.softmax(scores, dim=-1)  # [batch_size, seq_len, seq_len]
        
        # 加权汇总值向量
        output = torch.matmul(attn_weights, v)  # [batch_size, seq_len, d_model]
        
        return self.out(output)

自注意力机制示意图

自注意力的计算过程

1. 输入嵌入和线性投影

将输入序列转换为嵌入向量后,通过线性变换生成查询(Q)、键(K)和值(V):

def generate_qkv(input_embeddings, w_q, w_k, w_v):
    """生成查询、键、值向量"""
    # input_embeddings: [batch_size, seq_len, d_model]
    # w_q, w_k, w_v: [d_model, d_model]
    
    q = torch.matmul(input_embeddings, w_q)  # [batch_size, seq_len, d_model]
    k = torch.matmul(input_embeddings, w_k)  # [batch_size, seq_len, d_model]
    v = torch.matmul(input_embeddings, w_v)  # [batch_size, seq_len, d_model]
    
    return q, k, v

2. 计算注意力分数

计算查询与所有键的点积,并进行缩放:

def compute_attention_scores(q, k, scale=True):
    """计算注意力分数矩阵"""
    # q, k: [batch_size, seq_len, d_model]
    
    # 计算点积
    scores = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, seq_len, seq_len]
    
    # 缩放点积
    if scale:
        d_k = q.size(-1)
        scores = scores / math.sqrt(d_k)
    
    return scores

3. 应用掩码(可选)

对于需要掩盖的位置(如未来位置或填充位置),将分数设为负无穷:

def apply_mask(scores, mask):
    """应用注意力掩码"""
    # scores: [batch_size, seq_len, seq_len]
    # mask: [batch_size, seq_len, seq_len] 二进制掩码
    
    masked_scores = scores.masked_fill(mask == 0, -1e9)
    return masked_scores

掩码在自注意力中至关重要,特别是在解码器中,我们需要防止模型看到未来的信息。

4. Softmax归一化

将注意力分数转换为概率分布:

def normalize_scores(scores):
    """将分数归一化为概率分布"""
    # scores: [batch_size, seq_len, seq_len]
    
    attention_weights = F.softmax(scores, dim=-1)
    return attention_weights

5. 加权汇总

使用注意力权重对值向量进行加权求和:

def weighted_aggregation(attention_weights, v):
    """基于注意力权重聚合值向量"""
    # attention_weights: [batch_size, seq_len, seq_len]
    # v: [batch_size, seq_len, d_model]
    
    output = torch.matmul(attention_weights, v)  # [batch_size, seq_len, d_model]
    return output

自注意力的特点

优势

  1. 并行计算:与RNN不同,自注意力可以并行计算所有位置,提高训练效率

  2. 捕获长距离依赖:每个位置都可以直接关注任何其他位置,无论距离多远

def analyze_attention_span(attention_weights):
    """分析自注意力的关注范围"""
    # attention_weights: [batch_size, seq_len, seq_len]
    
    # 计算平均注意力距离
    batch_size, seq_len, _ = attention_weights.shape
    
    # 创建距离矩阵
    positions = torch.arange(seq_len).unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1]
    positions_t = positions.transpose(1, 2)  # [1, 1, seq_len]
    distance = torch.abs(positions - positions_t)  # [1, seq_len, seq_len]
    
    # 计算加权平均距离
    avg_distance = torch.sum(attention_weights * distance, dim=2)  # [batch_size, seq_len]
    
    return avg_distance.mean()
  1. 表示能力强:可以捕获复杂的序列内部关系,包括语法和语义依赖

局限

  1. 二次计算复杂度:注意力矩阵大小为 ,随序列长度二次增长

  2. 没有位置信息:需要额外的位置编码提供顺序信息

自注意力本身是位置无关的,即打乱序列的顺序不会改变结果。因此,必须引入位置编码以提供顺序信息。

变体和扩展

1. 带掩码的自注意力

在解码器中使用,确保预测时只使用当前及之前的信息:

def causal_self_attention(x):
    """带因果掩码的自注意力,用于自回归生成"""
    batch_size, seq_len, d_model = x.shape
    
    # 创建下三角掩码
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    
    # 计算自注意力,应用掩码
    q = self.q_linear(x)
    k = self.k_linear(x)
    v = self.v_linear(x)
    
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_model)
    scores = scores.masked_fill(mask == 0, -1e9)
    
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, v)
    
    return output

2. 局部敏感哈希注意力

减少计算复杂度的变体,适用于长序列:

def lsh_attention(q, k, v, n_buckets=256, n_hashes=8):
    """局部敏感哈希注意力,降低长序列的计算复杂度"""
    # 简化示例
    batch_size, seq_len, d_model = q.shape
    
    # 哈希函数(简化表示)
    def hash_vectors(vectors, n_buckets, n_hashes):
        # 在实际实现中,这应该使用随机投影等方法
        projections = torch.randn(n_hashes, d_model, n_buckets // 2)
        projected = torch.matmul(vectors.unsqueeze(1), projections)
        hashed = torch.argmax(torch.cat([projected, -projected], dim=-1), dim=-1)
        return hashed
    
    # 对查询和键进行哈希
    q_buckets = hash_vectors(q, n_buckets, n_hashes)  # [batch_size, n_hashes, seq_len]
    k_buckets = hash_vectors(k, n_buckets, n_hashes)
    
    # 后续计算:只计算同一桶内的注意力,此处省略复杂实现...
    
    return output

自注意力的应用案例

1. 自然语言处理

自注意力能够捕获语法和语义依赖关系:

def analyze_syntactic_attention(model, sentence):
    """分析自注意力如何捕获句法结构"""
    # 对句子进行标记化
    tokens = tokenizer(sentence, return_tensors="pt")
    
    # 前向传播,获取注意力权重
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True)
    
    # 取出特定层的注意力权重
    layer_attentions = outputs.attentions[6]  # 第7层注意力
    
    # 可视化句法关系
    for head_idx in [0, 3, 5]:  # 选择几个特定的注意力头
        attn = layer_attentions[0, head_idx].numpy()
        
        # 绘制热力图
        plt.figure(figsize=(10, 8))
        plt.imshow(attn, cmap="viridis")
        plt.xticks(range(len(tokens)), tokens, rotation=90)
        plt.yticks(range(len(tokens)), tokens)
        plt.title(f"Head {head_idx} Attention")
        plt.colorbar()
        plt.tight_layout()
        plt.show()

2. 计算机视觉

自注意力用于捕获图像中的长距离依赖:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, 
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        
    def forward(self, x):
        # x: [batch_size, channels, height, width]
        x = self.patch_embed(x)  # [batch_size, num_patches, embed_dim]
        
        # 添加分类令牌
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        
        # 添加位置编码
        x = x + self.pos_embed
        
        # 应用Transformer块
        for block in self.blocks:
            x = block(x)
        
        # 使用[CLS]令牌的表示进行分类
        return x[:, 0]

自注意力的效率和优化

计算复杂度分析

标准自注意力的计算复杂度:

操作时间复杂度空间复杂度
QK^T计算O(n²d)O(n²)
SoftmaxO(n²)O(n²)
乘以VO(n²d)O(nd)
总计O(n²d)O(n²)

其中n是序列长度,d是隐藏维度。

优化方法

对于长序列处理,可以采用以下优化方法:

def efficient_attention(q, k, v, chunk_size=128):
    """分块计算自注意力以优化内存使用"""
    batch_size, seq_len, d_model = q.shape
    
    # 初始化输出
    output = torch.zeros_like(q)
    
    # 分块计算
    for i in range(0, seq_len, chunk_size):
        end_idx = min(i + chunk_size, seq_len)
        
        # 当前块的查询
        q_chunk = q[:, i:end_idx]
        
        # 计算当前块与所有键的注意力
        chunk_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(d_model)
        chunk_weights = F.softmax(chunk_scores, dim=-1)
        
        # 聚合值向量
        chunk_output = torch.matmul(chunk_weights, v)
        
        # 更新输出
        output[:, i:end_idx] = chunk_output
    
    return output

小结

自注意力机制是Transformer架构的基础,通过直接计算序列内部的依赖关系,实现了捕获长距离依赖和并行计算的能力:

  1. 核心优势:全局上下文建模、捕获长距离依赖、并行计算
  2. 关键计算步骤:生成查询/键/值、计算注意力分数、归一化、加权聚合
  3. 主要挑战:二次计算复杂度、缺乏内在的位置信息
  4. 优化方向:降低计算复杂度、提高长序列处理能力

随着大语言模型的发展,自注意力的效率优化变得越来越重要,各种稀疏注意力和线性复杂度注意力变体也应运而生,成为推动大规模模型发展的关键技术。