自回归语言模型

预计学习时间:20分钟

自回归语言模型(Autoregressive Language Models)是一类通过预测序列中下一个元素来生成文本的模型,它们依赖于之前生成的所有内容来预测接下来的内容。

自回归模型的基本原理

自回归语言模型基于条件概率分解原理,将文本序列的联合概率分解为条件概率的连乘:

在这个公式中,每个词的预测都依赖于其之前的所有词。

自回归模型的特点

  • 单向上下文:只能利用左侧(之前)的信息预测下一个词
  • 生成过程:自左向右逐词生成
  • 天然适合文本生成任务:可以自然地按顺序产生连贯文本

自回归模型的实现

自回归语言模型的实现方式随着深度学习的发展而演进:

  1. 传统统计方法:N-gram模型
  2. 神经网络实现
    • RNN/LSTM/GRU:通过循环结构捕捉序列信息
    • Transformer解码器:通过掩码自注意力实现自回归特性
# 使用PyTorch实现简单的自回归模型示例
import torch
import torch.nn as nn

class SimpleAutoregressive(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(SimpleAutoregressive, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x):
        # x shape: [batch_size, sequence_length]
        embedded = self.embedding(x)  # [batch_size, sequence_length, embedding_dim]
        output, (hidden, cell) = self.rnn(embedded)  # output: [batch_size, sequence_length, hidden_dim]
        logits = self.fc(output)  # [batch_size, sequence_length, vocab_size]
        return logits
    
    def generate(self, start_tokens, max_length):
        """生成文本序列"""
        self.eval()
        current_tokens = start_tokens.clone()
        
        for _ in range(max_length):
            # 获取预测
            logits = self(current_tokens)
            # 取最后一个时间步的预测
            next_token_logits = logits[:, -1, :]
            # 采样下一个词
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # 将新词添加到序列中
            current_tokens = torch.cat([current_tokens, next_token], dim=1)
            
        return current_tokens

"自回归语言模型是当代大型语言模型的核心架构基础,GPT系列模型都采用这种方式构建。"

自回归模型的优缺点

优点

  • 生成流畅自然的文本:逐词生成符合人类阅读和写作习惯
  • 实现简单直观:训练和推理逻辑清晰
  • 适合开放式生成任务:故事创作、对话等

缺点

  • 存在曝光偏差问题(Exposure Bias):训练与推理时的条件分布不一致
  • 无法利用双向上下文:无法同时考虑左右两侧的信息
  • 生成过程无法并行:每次只能预测一个token,推理速度受限
模型类型代表模型关键特点
基于RNN的自回归模型AWD-LSTM长期依赖处理能力有限
基于Transformer的自回归模型GPT系列, LLaMA可捕捉更长距离依赖,规模化潜力大