title: 交叉注意力机制 description: 理解交叉注意力机制的工作原理及其在编码器-解码器架构中的应用 time: 35 sequence: 8

交叉注意力机制

交叉注意力的概念

交叉注意力(Cross-attention)是Transformer架构中连接编码器和解码器的关键机制。与自注意力不同,交叉注意力允许模型在两个不同序列之间建立联系,这在机器翻译等序列到序列任务中尤为重要。

交叉注意力的核心思想是:

  • 解码器的查询(Q)来自解码器层的输入
  • 键(K)和值(V)来自编码器的输出

这种设计使解码器能够"查看"输入序列的完整表示,并在生成输出时有选择地关注相关的输入信息。

交叉注意力是编码器-解码器架构中信息流动的桥梁,允许解码器在生成每个输出标记时动态关注源序列的相关部分。

交叉注意力的工作原理

交叉注意力的计算过程与自注意力类似,但关键区别在于查询、键和值的来源:

  1. 查询矩阵(Q): 从当前解码器层的输入中生成
  2. 键矩阵(K)和值矩阵(V): 从编码器的最终输出中生成

交叉注意力的计算公式为:

其中:

  • Q是解码器层的查询表示
  • K和V是编码器输出的键和值表示
  • 是键向量的维度

交叉注意力的优势

交叉注意力机制具有以下优势:

  1. 捕捉序列间依赖关系: 允许解码器输出根据输入序列的相关部分进行条件生成
  2. 动态关注: 解码器在生成每个标记时可以灵活关注输入序列的不同部分
  3. 信息桥接: 有效地将编码器提取的信息传递给解码器
  4. 长距离依赖建模: 处理源序列和目标序列之间的长距离依赖

交叉注意力在NLP任务中的应用

交叉注意力在多种NLP任务中发挥着重要作用:

任务交叉注意力的作用
机器翻译将源语言句子映射到目标语言句子
文本摘要关注原文中的关键信息以生成摘要
问答系统连接问题和上下文以提取答案
多模态任务连接不同模态(如图像和文本)的表示

代码实现

以下是使用PyTorch实现交叉注意力机制的简化示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CrossAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(CrossAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 为解码器查询创建线性层
        self.W_q = nn.Linear(d_model, d_model)
        # 为编码器键和值创建线性层
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        # 形状变换: (batch_size, seq_len, d_model) -> (batch_size, n_heads, seq_len, d_k)
        batch_size, seq_len = x.size(0), x.size(1)
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
    def forward(self, decoder_input, encoder_output):
        # decoder_input: [batch_size, tgt_len, d_model]
        # encoder_output: [batch_size, src_len, d_model]
        
        # 获取查询、键和值
        q = self.W_q(decoder_input)  # 从解码器输入
        k = self.W_k(encoder_output)  # 从编码器输出
        v = self.W_v(encoder_output)  # 从编码器输出
        
        # 拆分多头
        q = self.split_heads(q)  # [batch_size, n_heads, tgt_len, d_k]
        k = self.split_heads(k)  # [batch_size, n_heads, src_len, d_k]
        v = self.split_heads(v)  # [batch_size, n_heads, src_len, d_k]
        
        # 缩放点积注意力
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        
        # 应用注意力权重
        context = torch.matmul(attn_weights, v)
        
        # 重塑并投影
        context = context.transpose(1, 2).contiguous().view(
            context.size(0), -1, self.d_model)
        output = self.out_proj(context)
        
        return output, attn_weights

# 使用示例
batch_size = 2
src_len = 10  # 源序列长度
tgt_len = 5   # 目标序列长度
d_model = 512 # 模型维度
n_heads = 8   # 注意力头数

# 模拟编码器输出和解码器输入
decoder_input = torch.randn(batch_size, tgt_len, d_model)
encoder_output = torch.randn(batch_size, src_len, d_model)

cross_attn = CrossAttention(d_model, n_heads)
output, attention_weights = cross_attn(decoder_input, encoder_output)

print(f"输出形状: {output.shape}")  # [batch_size, tgt_len, d_model]
print(f"注意力权重形状: {attention_weights.shape}")  # [batch_size, n_heads, tgt_len, src_len]

交叉注意力与自注意力的比较

特性交叉注意力自注意力
查询、键、值来源Q来自解码器,K和V来自编码器Q、K、V均来自同一序列
主要功能连接不同序列的信息捕捉单一序列内的依赖关系
应用位置编码器-解码器架构的连接层编码器和解码器内部
信息流向从源序列到目标序列序列内部的信息交互

小结

交叉注意力机制是编码器-解码器架构中的关键组件,它使模型能够在生成输出时动态地关注输入序列的相关部分。通过将解码器的查询与编码器的键和值相结合,交叉注意力为序列到序列任务提供了强大的建模能力,成为现代NLP系统中不可或缺的一部分。

理解交叉注意力对于掌握Transformer架构和编码器-解码器模型的工作原理至关重要,尤其是在机器翻译、文本摘要等需要在两个不同序列之间建立联系的任务中。