Linear Attention改进
预计学习时间:25分钟
线性注意力(Linear Attention)是改进传统Transformer注意力机制计算复杂度的重要方向,将原始的二次方复杂度降低到线性复杂度,使模型能够处理更长的序列。本节介绍线性注意力的原理、主要方法和在大模型中的应用。
线性注意力的动机
传统的Softmax注意力计算复杂度为
随着序列长度增加,计算和内存需求呈二次方增长:
def standard_attention(q, k, v):
"""标准注意力计算 - O(n²)复杂度"""
# q, k, v 形状: [batch_size, seq_len, dim]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) # [b, n, n]
attn_weights = F.softmax(scores, dim=-1) # [b, n, n]
output = torch.matmul(attn_weights, v) # [b, n, dim]
return output
# 计算8K长度序列的注意力需要约256MB内存(float32)仅用于存储注意力矩阵
这种复杂度限制了Transformer处理长序列的能力,线性注意力旨在解决这一问题。
线性注意力的核心思想
线性注意力通过数学变换,将矩阵乘法顺序重排,避免显式计算
其中
线性注意力的关键在于利用矩阵乘法的结合律,将计算复杂度从
主要的线性注意力方法
1. Linearized Attention
最基本的线性注意力使用正值特征映射:
def linear_attention(q, k, v, feature_map=None):
"""使用特征映射的线性注意力实现 - O(n)复杂度"""
# 默认使用elu+1作为特征映射
if feature_map is None:
feature_map = lambda x: F.elu(x) + 1
q = feature_map(q) # [b, n, d]
k = feature_map(k) # [b, n, d]
# 计算k和v的乘积 (d×d矩阵)
kv = torch.bmm(k.transpose(-2, -1), v) # [b, d, d]
# 计算q和kv的乘积
qkv = torch.bmm(q, kv) # [b, n, d]
# 计算归一化因子
normalizer = torch.bmm(q, k.sum(dim=1).unsqueeze(-1)) # [b, n, 1]
return qkv / (normalizer + 1e-9) # 防止除零
2. Performer (FAVOR+)
Google研究的一种线性注意力方法,使用随机投影近似传统注意力:
def performer_attention(q, k, v, projection_dim=256, orthogonal=True, eps=1e-4):
"""Performer的FAVOR+算法实现"""
# 获取维度信息
batch_size, seq_len, d_model = q.shape
# 创建随机投影矩阵
if orthogonal:
# 使用正交随机特征
projection_matrix = create_orthogonal_matrix(d_model, projection_dim)
else:
# 使用iid高斯向量
projection_matrix = torch.randn(d_model, projection_dim) / math.sqrt(projection_dim)
projection_matrix = projection_matrix.to(q.device)
# 计算随机特征
q_prime = torch.exp(q @ projection_matrix) / math.sqrt(projection_dim)
k_prime = torch.exp(k @ projection_matrix) / math.sqrt(projection_dim)
# 线性计算注意力
kv = torch.bmm(k_prime.transpose(-2, -1), v)
qkv = torch.bmm(q_prime, kv)
# 归一化
normalizer = torch.bmm(q_prime, k_prime.sum(dim=1).unsqueeze(-1))
return qkv / (normalizer + eps)
3. Linear Transformer
简化版线性注意力,使用固定的特征映射:
def linear_transformer_attention(q, k, v):
"""Linear Transformer的注意力实现"""
# 使用elu+1特征映射
def phi(x):
return F.elu(x) + 1
q = phi(q) # [b, n, d]
k = phi(k) # [b, n, d]
# 线性复杂度计算
kv = torch.einsum('bnd,bne->bde', k, v) # [b, d, d_v]
qkv = torch.einsum('bmd,bde->bme', q, kv) # [b, m, d_v]
# 归一化
normalizer = torch.einsum('bmd,bd->bm', q, k.sum(dim=1)) # [b, m]
return qkv / normalizer.unsqueeze(-1) # [b, m, d_v]
4. Efficient Attention
Facebook (Meta) 研究的高效注意力实现:
def efficient_attention(q, k, v, dropout=0.1, chunk_size=128):
"""分块计算的高效线性注意力"""
# 获取维度
batch_size, seq_len, num_heads, head_dim = q.shape
# 重塑为4D张量
q = q.view(batch_size, num_heads, seq_len, head_dim)
k = k.view(batch_size, num_heads, seq_len, head_dim)
v = v.view(batch_size, num_heads, seq_len, head_dim)
# 计算注意力,分块处理以节省内存
outputs = []
for i in range(0, seq_len, chunk_size):
# 获取当前块
q_chunk = q[:, :, i:i+chunk_size]
chunk_len = q_chunk.size(2)
# 计算当前块的输出
scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(head_dim)
weights = F.softmax(scores, dim=-1)
if dropout > 0:
weights = F.dropout(weights, p=dropout)
chunk_output = torch.matmul(weights, v)
outputs.append(chunk_output)
# 合并所有块的输出
output = torch.cat(outputs, dim=2)
# 重塑回原始形状
output = output.view(batch_size, seq_len, num_heads * head_dim)
return output
核函数在线性注意力中的应用
线性注意力可以通过核方法解释,其中注意力被视为核函数:
常见的核函数选择:
1. 正核函数 (Positive Kernels)
def elu_kernel(x):
"""ELU+1核函数"""
return F.elu(x) + 1
def softmax_kernel(x, y, dim):
"""Softmax可以被作为一种核函数"""
# 计算点积
dots = torch.matmul(x, y.transpose(-2, -1))
# 应用Softmax
return F.softmax(dots, dim=dim)
2. 随机核函数
Performer使用随机特征近似核函数:
def random_fourier_features(x, n_features=256, sigma=1.0):
"""随机傅立叶特征"""
# 创建随机投影矩阵
proj = torch.randn(x.size(-1), n_features) * sigma
proj = proj.to(x.device)
# 投影输入向量
x_proj = x @ proj
# 应用激活函数
x_feats = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
return x_feats / math.sqrt(n_features)
线性注意力的理论分析
1. 误差分析
线性注意力与原始注意力的近似误差:
def approximation_error(q, k, v, linear_attn_fn, std_attn_fn):
"""计算线性注意力与标准注意力的近似误差"""
# 计算标准注意力
std_output = std_attn_fn(q, k, v)
# 计算线性注意力
linear_output = linear_attn_fn(q, k, v)
# 计算误差
error = torch.norm(std_output - linear_output, p='fro') / torch.norm(std_output, p='fro')
return error.item()
2. 计算复杂度比较
def complexity_comparison(seq_lengths, d_model=512):
"""比较不同序列长度下的计算复杂度"""
standard_complexity = [seq_len**2 * d_model for seq_len in seq_lengths]
linear_complexity = [seq_len * d_model**2 for seq_len in seq_lengths]
# 计算加速比
speedup = [std / lin for std, lin in zip(standard_complexity, linear_complexity)]
return {
"序列长度": seq_lengths,
"标准注意力FLOPs": standard_complexity,
"线性注意力FLOPs": linear_complexity,
"加速比": speedup
}
线性注意力的优化实践
1. 内存优化
def memory_efficient_linear_attention(q, k, v, chunk_size=1024):
"""使用分块策略节省内存的线性注意力"""
# 获取维度
batch_size, seq_len, head_dim = q.shape
# 应用特征映射
q = F.elu(q) + 1
k = F.elu(k) + 1
# 初始化KV累加器
kv = torch.zeros(batch_size, head_dim, head_dim, device=q.device)
# 分块累积K^T·V
for i in range(0, seq_len, chunk_size):
k_chunk = k[:, i:i+chunk_size]
v_chunk = v[:, i:i+chunk_size]
kv += torch.bmm(k_chunk.transpose(-2, -1), v_chunk)
# 计算输出
output = torch.bmm(q, kv)
# 计算归一化因子
k_sum = k.sum(dim=1).unsqueeze(-1) # [b, d, 1]
normalizer = torch.bmm(q, k_sum) # [b, n, 1]
return output / (normalizer + 1e-6)
2. 混合注意力策略
结合线性和标准注意力的优点:
class HybridAttention(nn.Module):
"""混合使用线性和标准注意力"""
def __init__(self, d_model, n_heads, linear_threshold=512):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.linear_threshold = linear_threshold
# 投影层
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
batch_size, seq_len, _ = q.shape
# 投影并重塑
q = self.q_proj(q).view(batch_size, seq_len, self.n_heads, self.head_dim)
k = self.k_proj(k).view(batch_size, seq_len, self.n_heads, self.head_dim)
v = self.v_proj(v).view(batch_size, seq_len, self.n_heads, self.head_dim)
# 选择注意力类型
if seq_len > self.linear_threshold:
# 使用线性注意力
output = self._linear_attention(q, k, v)
else:
# 使用标准注意力
output = self._standard_attention(q, k, v, mask)
# 投影输出
output = output.reshape(batch_size, seq_len, self.d_model)
output = self.out_proj(output)
return output
def _standard_attention(self, q, k, v, mask=None):
# 标准注意力实现...
pass
def _linear_attention(self, q, k, v):
# 线性注意力实现...
pass
大模型中的线性注意力应用
1. Longformer
Facebook开发的长文本处理模型:
def longformer_attention(q, k, v, local_attn_window=512, global_tokens=None):
"""Longformer的局部+全局稀疏注意力模式"""
batch_size, seq_len, num_heads, head_dim = q.shape
# 初始化输出
output = torch.zeros_like(q)
# 局部滑动窗口注意力
for i in range(seq_len):
# 确定窗口范围
window_start = max(0, i - local_attn_window // 2)
window_end = min(seq_len, i + local_attn_window // 2 + 1)
# 计算局部注意力
local_q = q[:, i:i+1]
local_k = k[:, window_start:window_end]
local_v = v[:, window_start:window_end]
# 标准注意力计算
scores = torch.matmul(local_q, local_k.transpose(-2, -1)) / math.sqrt(head_dim)
weights = F.softmax(scores, dim=-1)
local_output = torch.matmul(weights, local_v)
output[:, i:i+1] = local_output
# 全局注意力 (如果指定)
if global_tokens is not None:
for idx in global_tokens:
# 计算全局token的注意力
global_q = q[:, idx:idx+1]
scores = torch.matmul(global_q, k.transpose(-2, -1)) / math.sqrt(head_dim)
weights = F.softmax(scores, dim=-1)
global_output = torch.matmul(weights, v)
# 更新输出
output[:, idx:idx+1] = global_output
return output
2. Reformer
Google开发的基于LSH(局部敏感哈希)的高效注意力:
def lsh_attention(q, k, v, n_buckets=64, n_rounds=4):
"""基于LSH的局部敏感哈希注意力"""
batch_size, seq_len, d_model = q.shape
device = q.device
# 使用相同的键和查询进行哈希
x = q # [b, n, d]
# 进行多轮哈希以提高准确性
outputs = []
for r in range(n_rounds):
# 创建随机投影向量
random_rotations = torch.randn(d_model, n_buckets // 2, device=device)
# 计算哈希桶
rotated_x = torch.einsum('bnd,df->bnf', x, random_rotations)
bucket_indices = torch.argmax(torch.cat([rotated_x, -rotated_x], dim=-1), dim=-1)
# 按桶索引排序
sorted_indices = bucket_indices.argsort(dim=1)
x_sorted = batched_index_select(x, sorted_indices)
bucket_indices_sorted = batched_index_select(bucket_indices, sorted_indices)
# 找到桶的边界
boundaries = (bucket_indices_sorted[:, 1:] != bucket_indices_sorted[:, :-1]).long()
segment_ids = torch.cat([torch.zeros((batch_size, 1), device=device, dtype=torch.long),
torch.cumsum(boundaries, dim=1)], dim=1)
# 计算分块注意力
# ...实现分块注意力计算...
# 将结果按原始顺序排列
inv_sorted_indices = torch.argsort(sorted_indices, dim=1)
round_output = batched_index_select(output_sorted, inv_sorted_indices)
outputs.append(round_output)
# 平均多轮结果
output = torch.stack(outputs).mean(dim=0)
return output
3. xFormers中的实现
Facebook开发的高效Transformer库,整合了多种线性注意力方法:
def xformers_linear_attention(q, k, v, method="favor"):
"""xFormers库中的线性注意力实现示例"""
if method == "favor":
# FAVOR+方法
return favor_attention(q, k, v)
elif method == "kernel":
# 核函数方法
return kernel_attention(q, k, v)
elif method == "nystrom":
# Nyström方法
return nystrom_attention(q, k, v)
else:
raise ValueError(f"不支持的线性注意力方法: {method}")
线性注意力与标准注意力的比较
特性 | 标准Softmax注意力 | 线性注意力 |
---|---|---|
计算复杂度 | ||
内存需求 | ||
序列长度支持 | 有限 (通常 < 4K) | 更长 (可达数十万) |
表达能力 | 强 | 相对较弱 |
全局依赖建模 | 直接建模 | 间接近似 |
适用场景 | 中短文本,精确建模 | 长文本,高效处理 |
实施难度 | 相对简单 | 较复杂 |
小结
线性注意力是大型语言模型处理长序列的重要技术突破:
- 核心优势:将传统注意力的
复杂度降低到 ,显著提高长序列处理能力 - 主要方法:包括Linearized Attention、Performer、Linear Transformer等,通过矩阵分解或核函数近似实现线性复杂度
- 应用场景:特别适用于长文本处理、文档分析、长对话历史等场景
- 权衡取舍:线性注意力通常以略微降低表达能力为代价,获得更高的计算效率
随着大模型应用场景逐渐拓展到更长的输入场景,线性注意力及其变体将发挥越来越重要的作用,成为Transformer架构发展的关键方向之一。