Transformer模型结构

预计学习时间:30分钟

Transformer模型结构是理解现代大型语言模型的基础。Transformer采用编码器-解码器架构,但其核心创新在于使用多头自注意力机制替代了循环神经网络,同时配合位置编码、残差连接和层归一化等技术,形成了一个高效、可扩展的序列处理架构。

Transformer整体结构概览

Transformer模型由编码器和解码器两大部分组成,每部分又包含多个相同结构的层堆叠而成:

编码器-解码器框架

标准Transformer包括:

  • 编码器(Encoder):6个相同的层堆叠
  • 解码器(Decoder):6个相同的层堆叠
  • 连接机制:解码器通过注意力访问编码器输出

Transformer整体架构

# Transformer基本框架的PyTorch实现
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0.1,
        device="cpu",
        max_length=100,
    ):
        super(Transformer, self).__init__()
        
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            device,
        )
        
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            device,
        )
        
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        # src: [batch_size, src_len]
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask: [batch_size, 1, 1, src_len]
        return src_mask
    
    def make_trg_mask(self, trg):
        # trg: [batch_size, trg_len]
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask: [batch_size, 1, 1, trg_len]
        
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).to(self.device)
        # trg_sub_mask: [trg_len, trg_len]
        
        trg_mask = trg_pad_mask & trg_sub_mask
        # trg_mask: [batch_size, 1, trg_len, trg_len]
        return trg_mask
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

编码器详细结构

编码器负责将输入序列转换为连续表示。每个编码器层包含两个主要子层:

编码器层组成

  1. 多头自注意力子层

    • 允许编码器关注输入序列的不同位置
    • 每个位置可以汇集所有位置的信息
  2. 位置全连接前馈网络

    • 两个线性变换,中间带有ReLU激活
    • 对每个位置独立且相同地应用
class EncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(EncoderLayer, self).__init__()
        
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, value, key, query, mask=None):
        # 自注意力
        attention = self.attention(value, key, query, mask)
        
        # 第一个残差连接和层归一化
        x = self.norm1(attention + query)
        x = self.dropout(x)
        
        # 前馈网络
        forward = self.feed_forward(x)
        
        # 第二个残差连接和层归一化
        out = self.norm2(forward + x)
        out = self.dropout(out)
        
        return out

每个子层周围都有残差连接和层归一化,顺序是:子层输出 + 子层输入,然后进行层归一化。这种结构有助于训练更深的网络和改善梯度流动。

解码器详细结构

解码器负责生成输出序列,在处理过程中能够访问编码器的输出:

解码器层组成

  1. 掩码多头自注意力子层

    • 确保当前位置只能访问到过去的位置
    • 通过掩码实现自回归特性
  2. 编码器-解码器多头注意力子层

    • 允许解码器关注输入序列的相关部分
    • Query来自上一层,Key和Value来自编码器输出
  3. 位置全连接前馈网络

    • 与编码器的前馈网络结构相同
class DecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(DecoderLayer, self).__init__()
        
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        
        self.encoder_attention = MultiHeadAttention(embed_size, heads)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        
        self.norm3 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, trg_mask):
        # 自注意力
        attention = self.attention(x, x, x, trg_mask)
        
        # 第一个残差连接和层归一化
        x = self.norm1(attention + x)
        x = self.dropout(x)
        
        # 编码器-解码器注意力
        enc_attention = self.encoder_attention(
            enc_output, enc_output, x, src_mask
        )
        
        # 第二个残差连接和层归一化
        x = self.norm2(enc_attention + x)
        x = self.dropout(x)
        
        # 前馈网络
        forward = self.feed_forward(x)
        
        # 第三个残差连接和层归一化
        out = self.norm3(forward + x)
        out = self.dropout(out)
        
        return out

多头自注意力机制

多头自注意力是Transformer的核心创新,它允许模型同时关注不同表示子空间中的信息:

工作原理

  1. 线性投影:将输入向量投影到query、key、value空间
  2. 拆分多头:将每个投影分割为多个"头"
  3. 并行注意力:每个头独立计算注意力
  4. 拼接与投影:将多头输出拼接并投影回原始维度
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert (self.head_dim * heads == embed_size), "Embedding size must be divisible by heads"
        
        # 线性投影层
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
        
    def forward(self, values, keys, queries, mask=None):
        batch_size = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        
        # 线性投影
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        # 将投影结果分割为多头
        values = values.reshape(batch_size, value_len, self.heads, self.head_dim)
        keys = keys.reshape(batch_size, key_len, self.heads, self.head_dim)
        queries = queries.reshape(batch_size, query_len, self.heads, self.head_dim)
        
        # 计算注意力能量
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy: [batch_size, heads, query_len, key_len]
        
        # 缩放
        energy = energy / (self.head_dim ** 0.5)
        
        # 应用掩码
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # 注意力权重
        attention = torch.softmax(energy, dim=3)
        
        # 加权求和
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # out: [batch_size, query_len, heads, head_dim]
        
        # 重塑并投影回原始维度
        out = out.reshape(batch_size, query_len, self.embed_size)
        out = self.fc_out(out)
        
        return out

"多头注意力机制允许模型共同关注来自不同位置和不同表示子空间的信息,这极大增强了模型的表达能力。" — Vaswani et al.

位置全连接前馈网络

位置全连接前馈网络是Transformer中另一个关键组件:

特点与作用

  • 由两个线性变换组成,中间有ReLU激活
  • 独立应用于每个位置
  • 内部维度通常是输入维度的4倍
  • 引入非线性并增强模型容量
# 前馈网络实现
def position_wise_feed_forward(d_model, d_ff, dropout=0.1):
    return nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(d_ff, d_model)
    )

位置编码

由于Transformer没有循环结构,需要额外的机制来表示位置信息:

正弦余弦位置编码

  • 使用正弦和余弦函数的组合
  • 不同频率的周期函数对应不同位置
  • 允许模型学习相对位置关系
位置编码特点说明
唯一性每个位置有唯一表示
一致性相同距离的位置对有类似的关系
可外推性可扩展到未见过的序列长度
确定性固定函数,不需要学习

Transformer的关键组件配置

原始Transformer论文中的核心参数设置:

组件参数
整体架构编码器/解码器层数6
嵌入层嵌入维度(d_model)512
多头注意力头数(h)8
多头注意力每头维度(d_k, d_v)64
前馈网络内部维度(d_ff)2048
Dropout比率0.1

Transformer模型结构的优雅设计使其成为后来大型语言模型的基础架构。通过理解其组件的工作原理和相互作用,我们能更好地理解现代LLM的内部运作机制,为后续的模型优化和应用提供坚实基础。