Linear Attention改进

预计学习时间:25分钟

线性注意力(Linear Attention)是改进传统Transformer注意力机制计算复杂度的重要方向,将原始的二次方复杂度降低到线性复杂度,使模型能够处理更长的序列。本节介绍线性注意力的原理、主要方法和在大模型中的应用。

线性注意力的动机

传统的Softmax注意力计算复杂度为 ,其中 是序列长度:

随着序列长度增加,计算和内存需求呈二次方增长:

def standard_attention(q, k, v):
    """标准注意力计算 - O(n²)复杂度"""
    # q, k, v 形状: [batch_size, seq_len, dim]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # [b, n, n]
    attn_weights = F.softmax(scores, dim=-1)  # [b, n, n] 
    output = torch.matmul(attn_weights, v)  # [b, n, dim]
    return output

# 计算8K长度序列的注意力需要约256MB内存(float32)仅用于存储注意力矩阵

这种复杂度限制了Transformer处理长序列的能力,线性注意力旨在解决这一问题。

线性注意力的核心思想

线性注意力通过数学变换,将矩阵乘法顺序重排,避免显式计算 的注意力矩阵:

其中 是一个特征映射函数。

线性注意力的关键在于利用矩阵乘法的结合律,将计算复杂度从 降低到 ,当 时,这是显著的优化。

主要的线性注意力方法

1. Linearized Attention

最基本的线性注意力使用正值特征映射:

def linear_attention(q, k, v, feature_map=None):
    """使用特征映射的线性注意力实现 - O(n)复杂度"""
    # 默认使用elu+1作为特征映射
    if feature_map is None:
        feature_map = lambda x: F.elu(x) + 1
    
    q = feature_map(q)  # [b, n, d]
    k = feature_map(k)  # [b, n, d]
    
    # 计算k和v的乘积 (d×d矩阵)
    kv = torch.bmm(k.transpose(-2, -1), v)  # [b, d, d]
    
    # 计算q和kv的乘积
    qkv = torch.bmm(q, kv)  # [b, n, d]
    
    # 计算归一化因子
    normalizer = torch.bmm(q, k.sum(dim=1).unsqueeze(-1))  # [b, n, 1]
    
    return qkv / (normalizer + 1e-9)  # 防止除零

2. Performer (FAVOR+)

Google研究的一种线性注意力方法,使用随机投影近似传统注意力:

def performer_attention(q, k, v, projection_dim=256, orthogonal=True, eps=1e-4):
    """Performer的FAVOR+算法实现"""
    # 获取维度信息
    batch_size, seq_len, d_model = q.shape
    
    # 创建随机投影矩阵
    if orthogonal:
        # 使用正交随机特征
        projection_matrix = create_orthogonal_matrix(d_model, projection_dim)
    else:
        # 使用iid高斯向量
        projection_matrix = torch.randn(d_model, projection_dim) / math.sqrt(projection_dim)
    
    projection_matrix = projection_matrix.to(q.device)
    
    # 计算随机特征
    q_prime = torch.exp(q @ projection_matrix) / math.sqrt(projection_dim)
    k_prime = torch.exp(k @ projection_matrix) / math.sqrt(projection_dim)
    
    # 线性计算注意力
    kv = torch.bmm(k_prime.transpose(-2, -1), v)
    qkv = torch.bmm(q_prime, kv)
    
    # 归一化
    normalizer = torch.bmm(q_prime, k_prime.sum(dim=1).unsqueeze(-1))
    
    return qkv / (normalizer + eps)

3. Linear Transformer

简化版线性注意力,使用固定的特征映射:

def linear_transformer_attention(q, k, v):
    """Linear Transformer的注意力实现"""
    # 使用elu+1特征映射
    def phi(x):
        return F.elu(x) + 1
    
    q = phi(q)  # [b, n, d]
    k = phi(k)  # [b, n, d]
    
    # 线性复杂度计算
    kv = torch.einsum('bnd,bne->bde', k, v)  # [b, d, d_v]
    qkv = torch.einsum('bmd,bde->bme', q, kv)  # [b, m, d_v]
    
    # 归一化
    normalizer = torch.einsum('bmd,bd->bm', q, k.sum(dim=1))  # [b, m]
    
    return qkv / normalizer.unsqueeze(-1)  # [b, m, d_v]

4. Efficient Attention

Facebook (Meta) 研究的高效注意力实现:

def efficient_attention(q, k, v, dropout=0.1, chunk_size=128):
    """分块计算的高效线性注意力"""
    # 获取维度
    batch_size, seq_len, num_heads, head_dim = q.shape
    
    # 重塑为4D张量
    q = q.view(batch_size, num_heads, seq_len, head_dim)
    k = k.view(batch_size, num_heads, seq_len, head_dim)
    v = v.view(batch_size, num_heads, seq_len, head_dim)
    
    # 计算注意力,分块处理以节省内存
    outputs = []
    for i in range(0, seq_len, chunk_size):
        # 获取当前块
        q_chunk = q[:, :, i:i+chunk_size]
        chunk_len = q_chunk.size(2)
        
        # 计算当前块的输出
        scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        if dropout > 0:
            weights = F.dropout(weights, p=dropout)
        chunk_output = torch.matmul(weights, v)
        outputs.append(chunk_output)
    
    # 合并所有块的输出
    output = torch.cat(outputs, dim=2)
    
    # 重塑回原始形状
    output = output.view(batch_size, seq_len, num_heads * head_dim)
    
    return output

核函数在线性注意力中的应用

线性注意力可以通过核方法解释,其中注意力被视为核函数:

常见的核函数选择:

1. 正核函数 (Positive Kernels)

def elu_kernel(x):
    """ELU+1核函数"""
    return F.elu(x) + 1

def softmax_kernel(x, y, dim):
    """Softmax可以被作为一种核函数"""
    # 计算点积
    dots = torch.matmul(x, y.transpose(-2, -1))
    # 应用Softmax
    return F.softmax(dots, dim=dim)

2. 随机核函数

Performer使用随机特征近似核函数:

def random_fourier_features(x, n_features=256, sigma=1.0):
    """随机傅立叶特征"""
    # 创建随机投影矩阵
    proj = torch.randn(x.size(-1), n_features) * sigma
    proj = proj.to(x.device)
    
    # 投影输入向量
    x_proj = x @ proj
    
    # 应用激活函数
    x_feats = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
    
    return x_feats / math.sqrt(n_features)

线性注意力的理论分析

1. 误差分析

线性注意力与原始注意力的近似误差:

def approximation_error(q, k, v, linear_attn_fn, std_attn_fn):
    """计算线性注意力与标准注意力的近似误差"""
    # 计算标准注意力
    std_output = std_attn_fn(q, k, v)
    
    # 计算线性注意力
    linear_output = linear_attn_fn(q, k, v)
    
    # 计算误差
    error = torch.norm(std_output - linear_output, p='fro') / torch.norm(std_output, p='fro')
    
    return error.item()

2. 计算复杂度比较

def complexity_comparison(seq_lengths, d_model=512):
    """比较不同序列长度下的计算复杂度"""
    standard_complexity = [seq_len**2 * d_model for seq_len in seq_lengths]
    linear_complexity = [seq_len * d_model**2 for seq_len in seq_lengths]
    
    # 计算加速比
    speedup = [std / lin for std, lin in zip(standard_complexity, linear_complexity)]
    
    return {
        "序列长度": seq_lengths,
        "标准注意力FLOPs": standard_complexity,
        "线性注意力FLOPs": linear_complexity,
        "加速比": speedup
    }

线性注意力的优化实践

1. 内存优化

def memory_efficient_linear_attention(q, k, v, chunk_size=1024):
    """使用分块策略节省内存的线性注意力"""
    # 获取维度
    batch_size, seq_len, head_dim = q.shape
    
    # 应用特征映射
    q = F.elu(q) + 1
    k = F.elu(k) + 1
    
    # 初始化KV累加器
    kv = torch.zeros(batch_size, head_dim, head_dim, device=q.device)
    
    # 分块累积K^T·V
    for i in range(0, seq_len, chunk_size):
        k_chunk = k[:, i:i+chunk_size]
        v_chunk = v[:, i:i+chunk_size]
        kv += torch.bmm(k_chunk.transpose(-2, -1), v_chunk)
    
    # 计算输出
    output = torch.bmm(q, kv)
    
    # 计算归一化因子
    k_sum = k.sum(dim=1).unsqueeze(-1)  # [b, d, 1]
    normalizer = torch.bmm(q, k_sum)  # [b, n, 1]
    
    return output / (normalizer + 1e-6)

2. 混合注意力策略

结合线性和标准注意力的优点:

class HybridAttention(nn.Module):
    """混合使用线性和标准注意力"""
    def __init__(self, d_model, n_heads, linear_threshold=512):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.linear_threshold = linear_threshold
        
        # 投影层
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape
        
        # 投影并重塑
        q = self.q_proj(q).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(k).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(v).view(batch_size, seq_len, self.n_heads, self.head_dim)
        
        # 选择注意力类型
        if seq_len > self.linear_threshold:
            # 使用线性注意力
            output = self._linear_attention(q, k, v)
        else:
            # 使用标准注意力
            output = self._standard_attention(q, k, v, mask)
            
        # 投影输出
        output = output.reshape(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        
        return output
        
    def _standard_attention(self, q, k, v, mask=None):
        # 标准注意力实现...
        pass
        
    def _linear_attention(self, q, k, v):
        # 线性注意力实现...
        pass

大模型中的线性注意力应用

1. Longformer

Facebook开发的长文本处理模型:

def longformer_attention(q, k, v, local_attn_window=512, global_tokens=None):
    """Longformer的局部+全局稀疏注意力模式"""
    batch_size, seq_len, num_heads, head_dim = q.shape
    
    # 初始化输出
    output = torch.zeros_like(q)
    
    # 局部滑动窗口注意力
    for i in range(seq_len):
        # 确定窗口范围
        window_start = max(0, i - local_attn_window // 2)
        window_end = min(seq_len, i + local_attn_window // 2 + 1)
        
        # 计算局部注意力
        local_q = q[:, i:i+1]
        local_k = k[:, window_start:window_end]
        local_v = v[:, window_start:window_end]
        
        # 标准注意力计算
        scores = torch.matmul(local_q, local_k.transpose(-2, -1)) / math.sqrt(head_dim)
        weights = F.softmax(scores, dim=-1)
        local_output = torch.matmul(weights, local_v)
        
        output[:, i:i+1] = local_output
    
    # 全局注意力 (如果指定)
    if global_tokens is not None:
        for idx in global_tokens:
            # 计算全局token的注意力
            global_q = q[:, idx:idx+1]
            scores = torch.matmul(global_q, k.transpose(-2, -1)) / math.sqrt(head_dim)
            weights = F.softmax(scores, dim=-1)
            global_output = torch.matmul(weights, v)
            
            # 更新输出
            output[:, idx:idx+1] = global_output
    
    return output

2. Reformer

Google开发的基于LSH(局部敏感哈希)的高效注意力:

def lsh_attention(q, k, v, n_buckets=64, n_rounds=4):
    """基于LSH的局部敏感哈希注意力"""
    batch_size, seq_len, d_model = q.shape
    device = q.device
    
    # 使用相同的键和查询进行哈希
    x = q  # [b, n, d]
    
    # 进行多轮哈希以提高准确性
    outputs = []
    for r in range(n_rounds):
        # 创建随机投影向量
        random_rotations = torch.randn(d_model, n_buckets // 2, device=device)
        
        # 计算哈希桶
        rotated_x = torch.einsum('bnd,df->bnf', x, random_rotations)
        bucket_indices = torch.argmax(torch.cat([rotated_x, -rotated_x], dim=-1), dim=-1)
        
        # 按桶索引排序
        sorted_indices = bucket_indices.argsort(dim=1)
        x_sorted = batched_index_select(x, sorted_indices)
        bucket_indices_sorted = batched_index_select(bucket_indices, sorted_indices)
        
        # 找到桶的边界
        boundaries = (bucket_indices_sorted[:, 1:] != bucket_indices_sorted[:, :-1]).long()
        segment_ids = torch.cat([torch.zeros((batch_size, 1), device=device, dtype=torch.long),
                               torch.cumsum(boundaries, dim=1)], dim=1)
        
        # 计算分块注意力
        # ...实现分块注意力计算...
        
        # 将结果按原始顺序排列
        inv_sorted_indices = torch.argsort(sorted_indices, dim=1)
        round_output = batched_index_select(output_sorted, inv_sorted_indices)
        outputs.append(round_output)
    
    # 平均多轮结果
    output = torch.stack(outputs).mean(dim=0)
    
    return output

3. xFormers中的实现

Facebook开发的高效Transformer库,整合了多种线性注意力方法:

def xformers_linear_attention(q, k, v, method="favor"):
    """xFormers库中的线性注意力实现示例"""
    if method == "favor":
        # FAVOR+方法
        return favor_attention(q, k, v)
    elif method == "kernel":
        # 核函数方法
        return kernel_attention(q, k, v)
    elif method == "nystrom":
        # Nyström方法
        return nystrom_attention(q, k, v)
    else:
        raise ValueError(f"不支持的线性注意力方法: {method}")

线性注意力与标准注意力的比较

特性标准Softmax注意力线性注意力
计算复杂度
内存需求
序列长度支持有限 (通常 < 4K)更长 (可达数十万)
表达能力相对较弱
全局依赖建模直接建模间接近似
适用场景中短文本,精确建模长文本,高效处理
实施难度相对简单较复杂

小结

线性注意力是大型语言模型处理长序列的重要技术突破:

  1. 核心优势:将传统注意力的 复杂度降低到 ,显著提高长序列处理能力
  2. 主要方法:包括Linearized Attention、Performer、Linear Transformer等,通过矩阵分解或核函数近似实现线性复杂度
  3. 应用场景:特别适用于长文本处理、文档分析、长对话历史等场景
  4. 权衡取舍:线性注意力通常以略微降低表达能力为代价,获得更高的计算效率

随着大模型应用场景逐渐拓展到更长的输入场景,线性注意力及其变体将发挥越来越重要的作用,成为Transformer架构发展的关键方向之一。