Transformer模型结构
预计学习时间:30分钟
Transformer模型结构是理解现代大型语言模型的基础。Transformer采用编码器-解码器架构,但其核心创新在于使用多头自注意力机制替代了循环神经网络,同时配合位置编码、残差连接和层归一化等技术,形成了一个高效、可扩展的序列处理架构。
Transformer整体结构概览
Transformer模型由编码器和解码器两大部分组成,每部分又包含多个相同结构的层堆叠而成:
编码器-解码器框架
标准Transformer包括:
- 编码器(Encoder):6个相同的层堆叠
- 解码器(Decoder):6个相同的层堆叠
- 连接机制:解码器通过注意力访问编码器输出
# Transformer基本框架的PyTorch实现
import torch
import torch.nn as nn
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size,
trg_vocab_size,
src_pad_idx,
trg_pad_idx,
embed_size=512,
num_layers=6,
forward_expansion=4,
heads=8,
dropout=0.1,
device="cpu",
max_length=100,
):
super(Transformer, self).__init__()
self.encoder = Encoder(
src_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
max_length,
device,
)
self.decoder = Decoder(
trg_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
max_length,
device,
)
self.src_pad_idx = src_pad_idx
self.trg_pad_idx = trg_pad_idx
self.device = device
def make_src_mask(self, src):
# src: [batch_size, src_len]
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
# src_mask: [batch_size, 1, 1, src_len]
return src_mask
def make_trg_mask(self, trg):
# trg: [batch_size, trg_len]
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
# trg_pad_mask: [batch_size, 1, 1, trg_len]
trg_len = trg.shape[1]
trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).to(self.device)
# trg_sub_mask: [trg_len, trg_len]
trg_mask = trg_pad_mask & trg_sub_mask
# trg_mask: [batch_size, 1, trg_len, trg_len]
return trg_mask
def forward(self, src, trg):
src_mask = self.make_src_mask(src)
trg_mask = self.make_trg_mask(trg)
enc_src = self.encoder(src, src_mask)
out = self.decoder(trg, enc_src, src_mask, trg_mask)
return out
编码器详细结构
编码器负责将输入序列转换为连续表示。每个编码器层包含两个主要子层:
编码器层组成
-
多头自注意力子层
- 允许编码器关注输入序列的不同位置
- 每个位置可以汇集所有位置的信息
-
位置全连接前馈网络
- 两个线性变换,中间带有ReLU激活
- 对每个位置独立且相同地应用
class EncoderLayer(nn.Module):
def __init__(self, embed_size, heads, forward_expansion, dropout):
super(EncoderLayer, self).__init__()
self.attention = MultiHeadAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask=None):
# 自注意力
attention = self.attention(value, key, query, mask)
# 第一个残差连接和层归一化
x = self.norm1(attention + query)
x = self.dropout(x)
# 前馈网络
forward = self.feed_forward(x)
# 第二个残差连接和层归一化
out = self.norm2(forward + x)
out = self.dropout(out)
return out
每个子层周围都有残差连接和层归一化,顺序是:子层输出 + 子层输入,然后进行层归一化。这种结构有助于训练更深的网络和改善梯度流动。
解码器详细结构
解码器负责生成输出序列,在处理过程中能够访问编码器的输出:
解码器层组成
-
掩码多头自注意力子层
- 确保当前位置只能访问到过去的位置
- 通过掩码实现自回归特性
-
编码器-解码器多头注意力子层
- 允许解码器关注输入序列的相关部分
- Query来自上一层,Key和Value来自编码器输出
-
位置全连接前馈网络
- 与编码器的前馈网络结构相同
class DecoderLayer(nn.Module):
def __init__(self, embed_size, heads, forward_expansion, dropout):
super(DecoderLayer, self).__init__()
self.attention = MultiHeadAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.encoder_attention = MultiHeadAttention(embed_size, heads)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size),
)
self.norm3 = nn.LayerNorm(embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, trg_mask):
# 自注意力
attention = self.attention(x, x, x, trg_mask)
# 第一个残差连接和层归一化
x = self.norm1(attention + x)
x = self.dropout(x)
# 编码器-解码器注意力
enc_attention = self.encoder_attention(
enc_output, enc_output, x, src_mask
)
# 第二个残差连接和层归一化
x = self.norm2(enc_attention + x)
x = self.dropout(x)
# 前馈网络
forward = self.feed_forward(x)
# 第三个残差连接和层归一化
out = self.norm3(forward + x)
out = self.dropout(out)
return out
多头自注意力机制
多头自注意力是Transformer的核心创新,它允许模型同时关注不同表示子空间中的信息:
工作原理
- 线性投影:将输入向量投影到query、key、value空间
- 拆分多头:将每个投影分割为多个"头"
- 并行注意力:每个头独立计算注意力
- 拼接与投影:将多头输出拼接并投影回原始维度
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, heads):
super(MultiHeadAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (self.head_dim * heads == embed_size), "Embedding size must be divisible by heads"
# 线性投影层
self.values = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.queries = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, values, keys, queries, mask=None):
batch_size = queries.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
# 线性投影
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
# 将投影结果分割为多头
values = values.reshape(batch_size, value_len, self.heads, self.head_dim)
keys = keys.reshape(batch_size, key_len, self.heads, self.head_dim)
queries = queries.reshape(batch_size, query_len, self.heads, self.head_dim)
# 计算注意力能量
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# energy: [batch_size, heads, query_len, key_len]
# 缩放
energy = energy / (self.head_dim ** 0.5)
# 应用掩码
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 注意力权重
attention = torch.softmax(energy, dim=3)
# 加权求和
out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
# out: [batch_size, query_len, heads, head_dim]
# 重塑并投影回原始维度
out = out.reshape(batch_size, query_len, self.embed_size)
out = self.fc_out(out)
return out
"多头注意力机制允许模型共同关注来自不同位置和不同表示子空间的信息,这极大增强了模型的表达能力。" — Vaswani et al.
位置全连接前馈网络
位置全连接前馈网络是Transformer中另一个关键组件:
特点与作用
- 由两个线性变换组成,中间有ReLU激活
- 独立应用于每个位置
- 内部维度通常是输入维度的4倍
- 引入非线性并增强模型容量
# 前馈网络实现
def position_wise_feed_forward(d_model, d_ff, dropout=0.1):
return nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
位置编码
由于Transformer没有循环结构,需要额外的机制来表示位置信息:
正弦余弦位置编码
- 使用正弦和余弦函数的组合
- 不同频率的周期函数对应不同位置
- 允许模型学习相对位置关系
位置编码特点 | 说明 |
---|---|
唯一性 | 每个位置有唯一表示 |
一致性 | 相同距离的位置对有类似的关系 |
可外推性 | 可扩展到未见过的序列长度 |
确定性 | 固定函数,不需要学习 |
Transformer的关键组件配置
原始Transformer论文中的核心参数设置:
组件 | 参数 | 值 |
---|---|---|
整体架构 | 编码器/解码器层数 | 6 |
嵌入层 | 嵌入维度(d_model) | 512 |
多头注意力 | 头数(h) | 8 |
多头注意力 | 每头维度(d_k, d_v) | 64 |
前馈网络 | 内部维度(d_ff) | 2048 |
Dropout | 比率 | 0.1 |
Transformer模型结构的优雅设计使其成为后来大型语言模型的基础架构。通过理解其组件的工作原理和相互作用,我们能更好地理解现代LLM的内部运作机制,为后续的模型优化和应用提供坚实基础。