Self-Attention

预计学习时间:30分钟

**自注意力(Self-Attention)**是Transformer架构的核心机制,它允许模型在处理序列时关注序列中的所有位置,为每个位置计算加权上下文表示。通过学习输入序列中元素之间的关联强度,自注意力机制能有效捕获长距离依赖关系,大幅提升模型对序列数据的处理能力。

自注意力机制的基本原理

自注意力的核心思想是让序列中的每个元素能够"看到"整个序列,并有选择性地关注相关部分:

注意力计算流程

自注意力通过以下步骤计算:

  1. 线性投影:将输入向量转换为查询(Q)、键(K)、值(V)三种表示
  2. 注意力分数:计算查询与所有键的相似度作为注意力分数
  3. 权重归一化:对分数应用softmax函数获得概率分布
  4. 加权聚合:使用归一化权重对值向量进行加权求和

自注意力计算流程

# 自注意力机制的基本实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        
        # 线性变换矩阵
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        
        # 输出投影
        self.fc_out = nn.Linear(embed_size, embed_size)
        
    def forward(self, x, mask=None):
        # x: [batch_size, seq_len, embed_size]
        batch_size = x.shape[0]
        seq_length = x.shape[1]
        
        # 1. 获取查询、键、值向量
        queries = self.query(x)  # [batch_size, seq_len, embed_size]
        keys = self.key(x)       # [batch_size, seq_len, embed_size]
        values = self.value(x)   # [batch_size, seq_len, embed_size]
        
        # 2. 计算注意力分数 (点积注意力)
        # 使用爱因斯坦求和约定进行批量矩阵乘法
        energy = torch.einsum("nqd,nkd->nqk", [queries, keys])
        # energy: [batch_size, seq_len, seq_len]
        
        # 3. 缩放注意力分数
        energy = energy / math.sqrt(self.embed_size)
        
        # 4. 掩蔽填充位置(可选)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # 5. Softmax归一化获取注意力权重
        attention = F.softmax(energy, dim=2)  # [batch_size, seq_len, seq_len]
        
        # 6. 加权聚合值向量
        out = torch.einsum("nqk,nkd->nqd", [attention, values])
        # out: [batch_size, seq_len, embed_size]
        
        # 7. 输出线性变换
        out = self.fc_out(out)
        
        return out, attention

在实践中,自注意力的计算复杂度是O(n²),其中n是序列长度。这意味着对于长序列,计算成本会迅速增加,成为Transformer模型的主要瓶颈之一。

自注意力的数学表达

标准自注意力机制可以用以下数学公式表示:

其中:

  • 分别是查询、键和值矩阵
  • 是键向量的维度
  • 是缩放因子,防止点积过大导致softmax梯度消失

这种形式的注意力被称为标准点积注意力(Scaled Dot-Product Attention)。

"自注意力能够学习到输入序列中所有位置之间的依赖关系,无论它们之间的距离有多远,这使得它在捕获长距离依赖方面比RNN更有优势。" — Vaswani et al.

各种自注意力变体

自注意力机制有多种变体和扩展,各有特点:

1. 点积注意力与加性注意力

注意力类型计算方式特点
点积注意力计算高效,尤其在高维向量上
加性注意力更适合低维向量,数值稳定性更好

2. 硬注意力与软注意力

  • 软注意力(Soft Attention):分配连续的注意力权重,完全可微
  • 硬注意力(Hard Attention):仅关注单个位置,使用采样,需要强化学习

3. 稀疏注意力变体

为了解决自注意力的计算复杂度问题,多种稀疏注意力机制被提出:

# 简单的局部注意力实现(仅关注上下文窗口)
def local_attention(queries, keys, values, window_size=5, mask=None):
    batch_size, seq_len, dim = queries.shape
    
    # 初始化输出
    outputs = torch.zeros_like(queries)
    
    # 对每个位置计算局部注意力
    for i in range(seq_len):
        # 确定局部窗口范围
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        
        # 获取局部键值
        local_keys = keys[:, start:end, :]
        local_values = values[:, start:end, :]
        
        # 当前查询
        q = queries[:, i:i+1, :]  # [batch_size, 1, dim]
        
        # 计算注意力分数
        scores = torch.bmm(q, local_keys.transpose(1, 2)) / math.sqrt(dim)
        
        # 应用掩码(如果有)
        if mask is not None:
            local_mask = mask[:, i:i+1, start:end]
            scores = scores.masked_fill(local_mask == 0, float("-1e20"))
        
        # Softmax获取权重
        attn_weights = F.softmax(scores, dim=2)  # [batch_size, 1, window_size]
        
        # 加权求和
        context = torch.bmm(attn_weights, local_values)  # [batch_size, 1, dim]
        outputs[:, i:i+1, :] = context
    
    return outputs

自注意力的优势与局限

自注意力相比传统序列模型带来多方面优势:

主要优势

  1. 全局依赖捕获:直接建立任意距离的token间联系
  2. 并行计算:所有位置可以同时计算
  3. 直观可解释:注意力权重提供了模型决策的可视化解释
  4. 灵活适应:同样的机制适用于不同长度的序列

关键局限

  1. 二次计算复杂度:计算需求随序列长度平方增长
  2. 位置信息缺失:需要额外的位置编码提供顺序信息
  3. 高内存需求:需要存储完整的注意力矩阵
# 可视化自注意力权重
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def visualize_attention(tokens, attention_weights):
    """可视化自注意力权重"""
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="YlGnBu",
        annot=False
    )
    plt.title("Self-Attention Weights")
    plt.xlabel("Key Tokens")
    plt.ylabel("Query Tokens")
    plt.tight_layout()
    return plt

# 示例使用
tokens = ["The", "cat", "sits", "on", "the", "mat", "."]
# 随机生成的注意力权重矩阵
random_attention = np.random.rand(len(tokens), len(tokens))
random_attention = random_attention / random_attention.sum(axis=1, keepdims=True)
attention_plot = visualize_attention(tokens, random_attention)

自注意力在不同任务中的应用

自注意力已成为多种NLP和计算机视觉任务的关键组件:

语言建模

  • 捕获词之间的上下文依赖关系
  • 处理长距离语法关系和指代消解

机器翻译

  • 在编码器中建立源语言词之间的关系
  • 在编码器-解码器注意力中关联目标语言和源语言

文本摘要

  • 识别文本中的关键信息
  • 学习句子间的关系

计算机视觉

  • 处理图像作为"像素序列"
  • 捕获图像不同区域之间的关系

实现中的关键技巧

在实际实现自注意力时,一些技巧至关重要:

1. 掩码操作

  • 填充掩码(Padding Mask):防止模型关注填充位置
  • 未来掩码(Future Mask):在自回归生成中防止信息泄露

2. 注意力丢弃(Attention Dropout)

在softmax后应用Dropout,增加模型鲁棒性:

# 带Dropout的自注意力
def attention_with_dropout(query, key, value, mask=None, dropout=None):
    "点积注意力实现,带掩码和dropout"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    attn = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        attn = dropout(attn)
    
    return torch.matmul(attn, value), attn

自注意力机制作为Transformer和现代大型语言模型的核心部件,彻底改变了序列数据建模的方式。通过理解其工作原理和优化手段,我们能更好地理解和利用这些强大模型的内部机制。