位置编码

预计学习时间:25分钟

位置编码是Transformer模型的关键组件,用于为模型提供序列中token的位置信息,因为自注意力机制本身不包含位置信息。

为什么需要位置编码?

Transformer中的自注意力机制是位置无关的,即:

如果我们打乱输入序列中token的顺序,自注意力层的计算结果仍然相同。然而,语言是有序的,单词或token的顺序对语义至关重要。

如果没有位置信息,"猫追狗"和"狗追猫"在自注意力中将被视为等价的,这显然是不正确的。

位置编码的主要方法

1. 绝对位置编码 (Absolute Positional Encoding)

学习型位置嵌入

最简单的方法是学习每个位置的嵌入向量:

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.positional_embeddings = nn.Parameter(torch.zeros(max_len, d_model))
        nn.init.normal_(self.positional_embeddings, mean=0, std=0.02)
        
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        return x + self.positional_embeddings[:seq_len, :]

正弦位置编码 (Sinusoidal)

原始Transformer论文中使用的正弦位置编码:

实现代码:

def sinusoidal_positional_encoding(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe

正弦位置编码的优势:

  • 可以扩展到未见过的序列长度
  • 提供不同频率的位置信号
  • 固定编码,不需要学习参数

正弦位置编码可视化

2. 相对位置编码 (Relative Positional Encoding)

相对位置编码关注的是token之间的相对距离,而不是它们在序列中的绝对位置。

Shaw等人的方法 (Transformer-XL)

在注意力计算中加入相对位置信息:

其中R是相对位置编码矩阵。

class RelativePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_rel_distance=32):
        super().__init__()
        self.max_rel_distance = max_rel_distance
        self.rel_embeddings = nn.Parameter(torch.zeros(2 * max_rel_distance + 1, d_model))
        nn.init.normal_(self.rel_embeddings, mean=0, std=0.02)
        
    def forward(self, q, k):
        # 计算q和k之间的相对位置
        # 实现相对位置的注意力计算
        # ...
        pass

3. 旋转位置编码 (RoPE - Rotary Position Embedding)

最近在现代大语言模型中流行的位置编码方法,将旋转矩阵应用于查询和键向量:

Misplaced &\cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \mathbf{q}_{m,i}$$ $$\mathbf{k}_{n,i}^{\theta} = \begin{pmatrix} \cos n\theta_i & -\sin n\theta_i \\ \sin n\theta_i & \cos n\theta_i \end{pmatrix} \mathbf{k}_{n,i}$$ ```python def apply_rotary_embedding(x, cos, sin, position_ids): # x: [batch_size, seq_len, num_heads, head_dim] # 将位置信息编码到x中 cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] # 将向量分成偶数和奇数部分 x_2d = x.view(*x.shape[:-1], -1, 2) x_2d_rot = torch.stack([-x_2d[..., 1::2, 0], x_2d[..., 1::2, 1]], dim=-1) x_2d = torch.stack([x_2d[..., ::2, 0], x_2d[..., ::2, 1]], dim=-1) # 应用旋转 result_2d = torch.zeros_like(x_2d) result_2d[..., 0] = x_2d[..., 0] * cos - x_2d[..., 1] * sin result_2d[..., 1] = x_2d[..., 0] * sin + x_2d[..., 1] * cos # 合并结果 # ... return x_rope ``` RoPE的优势: - 更好地捕捉相对位置关系 - 在长序列外推时表现更好 - 计算效率高 - 支持上下文扩展技术 ## 位置编码的演变与比较 不同大语言模型使用的位置编码: | 模型 | 位置编码类型 | 注意事项 | | --- | --- | --- | | BERT | 学习型绝对位置编码 | 限制最大序列长度为512 | | GPT-2 | 学习型绝对位置编码 | 限制最大序列长度为1024 | | T5 | 相对位置编码 | 使用有界相对位置窗口 | | Transformer-XL | 相对位置编码 | 优化长序列学习 | | ALBERT | 因式分解嵌入 | 降低参数量 | | LLaMA | 旋转位置编码(RoPE) | 更好的长度外推能力 | | PaLM | 旋转位置编码(RoPE) | 支持更长上下文 | ## 位置编码与上下文长度扩展 现代大语言模型需要处理超长上下文,因此位置编码的外推能力变得至关重要。 ### 位置插值法 (Position Interpolation) 通过在原有位置编码之间进行插值扩展上下文长度: ```python def position_interpolation(original_pe, target_length): # original_pe: [orig_length, dim] orig_length = original_pe.shape[0] # 创建插值点 orig_positions = torch.arange(orig_length).float() target_positions = torch.linspace(0, orig_length - 1, target_length).float() # 对每个维度进行线性插值 extended_pe = torch.zeros(target_length, original_pe.shape[1]) for i in range(original_pe.shape[1]): extended_pe[:, i] = torch.interp(target_positions, orig_positions, original_pe[:, i]) return extended_pe ``` ### RoPE缩放技术 针对旋转位置编码的特殊扩展方法: ```python def rope_scaling(theta, scaling_factor): """降低RoPE中的旋转频率,扩展位置编码范围""" return theta / scaling_factor ``` <Callout type="note"> NTK-aware缩放是一种更有效的RoPE缩放方法,它采用$\theta_i' = \theta_i \cdot base^{-2i/d \cdot (1-s)}$,其中$s$是缩放因子。 </Callout> ## 位置编码与其他组件的交互 位置编码与其他模型组件的关系: ### 1. 与自注意力机制的相互作用 在不同架构中位置编码应用的位置不同: - **前置型**: $\text{Attention}(\text{PE}(Q), \text{PE}(K), V)$ - **内嵌型**: 在计算注意力分数时融入位置信息 - **后置型**: 在注意力输出后添加位置信息 ### 2. 位置编码与长度泛化 位置编码直接影响模型的长度泛化能力: ```python def evaluate_length_generalization(model, position_encoding_type, test_lengths): """评估不同位置编码对长度泛化的影响""" results = {} for length in test_lengths: # 生成测试数据 test_data = generate_test_data(length) # 评估模型性能 performance = evaluate_model(model, test_data) results[length] = performance return results ``` ## 实际应用与最佳实践 ### 1. 选择适合的位置编码 根据任务特点选择位置编码: - 短序列任务: 简单的学习型或正弦位置编码 - 长序列任务: RoPE或高级相对位置编码 - 需要双向上下文: 适合使用相对位置编码 - 需要长度泛化: 首选RoPE ### 2. 位置编码实现技巧 ```python class PositionalEncodingFactory: @staticmethod def create(encoding_type, config): if encoding_type == "learned": return LearnedPositionalEmbedding(config.max_len, config.d_model) elif encoding_type == "sinusoidal": pe = sinusoidal_positional_encoding(config.max_len, config.d_model) return lambda x: x + pe.to(x.device) elif encoding_type == "relative": return RelativePositionalEncoding(config.d_model, config.max_rel_distance) elif encoding_type == "rope": return RotaryPositionalEmbedding(config.d_model) else: raise ValueError(f"Unknown positional encoding type: {encoding_type}") ``` ## 小结 位置编码是大语言模型的关键组件,负责为模型提供序列位置信息: 1. **绝对位置编码**适合处理固定长度的序列 2. **相对位置编码**提高了模型处理长序列的能力 3. **旋转位置编码(RoPE)**兼具计算效率和外推能力,是当前大型语言模型的主流选择 4. 位置编码直接影响模型的**上下文长度限制**和**长度泛化能力** 5. 随着大语言模型需要处理越来越长的上下文,位置编码技术仍在**持续演化**中 选择合适的位置编码需要权衡模型性能、计算效率和应用场景需求。