title: 交叉注意力机制 description: 理解交叉注意力机制的工作原理及其在编码器-解码器架构中的应用 time: 35 sequence: 8
交叉注意力机制
交叉注意力的概念
交叉注意力(Cross-attention)是Transformer架构中连接编码器和解码器的关键机制。与自注意力不同,交叉注意力允许模型在两个不同序列之间建立联系,这在机器翻译等序列到序列任务中尤为重要。
交叉注意力的核心思想是:
- 解码器的查询(Q)来自解码器层的输入
- 键(K)和值(V)来自编码器的输出
这种设计使解码器能够"查看"输入序列的完整表示,并在生成输出时有选择地关注相关的输入信息。
交叉注意力是编码器-解码器架构中信息流动的桥梁,允许解码器在生成每个输出标记时动态关注源序列的相关部分。
交叉注意力的工作原理
交叉注意力的计算过程与自注意力类似,但关键区别在于查询、键和值的来源:
- 查询矩阵(Q): 从当前解码器层的输入中生成
- 键矩阵(K)和值矩阵(V): 从编码器的最终输出中生成
交叉注意力的计算公式为:
其中:
- Q是解码器层的查询表示
- K和V是编码器输出的键和值表示
是键向量的维度
交叉注意力的优势
交叉注意力机制具有以下优势:
- 捕捉序列间依赖关系: 允许解码器输出根据输入序列的相关部分进行条件生成
- 动态关注: 解码器在生成每个标记时可以灵活关注输入序列的不同部分
- 信息桥接: 有效地将编码器提取的信息传递给解码器
- 长距离依赖建模: 处理源序列和目标序列之间的长距离依赖
交叉注意力在NLP任务中的应用
交叉注意力在多种NLP任务中发挥着重要作用:
任务 | 交叉注意力的作用 |
---|---|
机器翻译 | 将源语言句子映射到目标语言句子 |
文本摘要 | 关注原文中的关键信息以生成摘要 |
问答系统 | 连接问题和上下文以提取答案 |
多模态任务 | 连接不同模态(如图像和文本)的表示 |
代码实现
以下是使用PyTorch实现交叉注意力机制的简化示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CrossAttention(nn.Module):
def __init__(self, d_model, n_heads):
super(CrossAttention, self).__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 为解码器查询创建线性层
self.W_q = nn.Linear(d_model, d_model)
# 为编码器键和值创建线性层
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def split_heads(self, x):
# 形状变换: (batch_size, seq_len, d_model) -> (batch_size, n_heads, seq_len, d_k)
batch_size, seq_len = x.size(0), x.size(1)
return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
def forward(self, decoder_input, encoder_output):
# decoder_input: [batch_size, tgt_len, d_model]
# encoder_output: [batch_size, src_len, d_model]
# 获取查询、键和值
q = self.W_q(decoder_input) # 从解码器输入
k = self.W_k(encoder_output) # 从编码器输出
v = self.W_v(encoder_output) # 从编码器输出
# 拆分多头
q = self.split_heads(q) # [batch_size, n_heads, tgt_len, d_k]
k = self.split_heads(k) # [batch_size, n_heads, src_len, d_k]
v = self.split_heads(v) # [batch_size, n_heads, src_len, d_k]
# 缩放点积注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
# 应用注意力权重
context = torch.matmul(attn_weights, v)
# 重塑并投影
context = context.transpose(1, 2).contiguous().view(
context.size(0), -1, self.d_model)
output = self.out_proj(context)
return output, attn_weights
# 使用示例
batch_size = 2
src_len = 10 # 源序列长度
tgt_len = 5 # 目标序列长度
d_model = 512 # 模型维度
n_heads = 8 # 注意力头数
# 模拟编码器输出和解码器输入
decoder_input = torch.randn(batch_size, tgt_len, d_model)
encoder_output = torch.randn(batch_size, src_len, d_model)
cross_attn = CrossAttention(d_model, n_heads)
output, attention_weights = cross_attn(decoder_input, encoder_output)
print(f"输出形状: {output.shape}") # [batch_size, tgt_len, d_model]
print(f"注意力权重形状: {attention_weights.shape}") # [batch_size, n_heads, tgt_len, src_len]
交叉注意力与自注意力的比较
特性 | 交叉注意力 | 自注意力 |
---|---|---|
查询、键、值来源 | Q来自解码器,K和V来自编码器 | Q、K、V均来自同一序列 |
主要功能 | 连接不同序列的信息 | 捕捉单一序列内的依赖关系 |
应用位置 | 编码器-解码器架构的连接层 | 编码器和解码器内部 |
信息流向 | 从源序列到目标序列 | 序列内部的信息交互 |
小结
交叉注意力机制是编码器-解码器架构中的关键组件,它使模型能够在生成输出时动态地关注输入序列的相关部分。通过将解码器的查询与编码器的键和值相结合,交叉注意力为序列到序列任务提供了强大的建模能力,成为现代NLP系统中不可或缺的一部分。
理解交叉注意力对于掌握Transformer架构和编码器-解码器模型的工作原理至关重要,尤其是在机器翻译、文本摘要等需要在两个不同序列之间建立联系的任务中。