循环神经网络(RNN/LSTM/GRU)

核心定位: 处理序列数据的经典架构,从 1986 年 RNN 到 2023 年被 Transformer 大规模取代的演进史。 关联笔记: 神经网络神经网络层类型详解Transformer卷积神经网络


1. 为什么需要 RNN

1.1 序列数据的挑战

序列数据无处不在:

  • 文本:单词序列
  • 语音:音频帧序列
  • 视频:图像帧序列
  • 时间序列:股票价格、气温变化

传统神经网络的局限:

  • 固定输入长度: 无法处理可变长度序列
  • 无记忆能力: 每个输入独立处理
  • 参数爆炸: 长序列需要巨量参数

1.2 RNN 的核心思想

循环连接: 隐状态在时间步之间传递

时间步 t-1        时间步 t         时间步 t+1
   ↓                ↓                ↓
  x_{t-1}          x_t             x_{t+1}
   ↓                ↓                ↓
  [RNN] ──h_{t-1}→ [RNN] ──h_t→   [RNN]
   ↓                ↓                ↓
  y_{t-1}          y_t             y_{t+1}

数学表达:

关键特性:

  1. 参数共享: 所有时间步共享权重
  2. 记忆机制: 隐状态 编码历史信息
  3. 可变长度: 理论上可处理任意长度序列

1.3 RNN 的输入输出模式

模式输入输出应用场景示例
一对一单个单个传统任务图像分类
一对多单个序列序列生成图像描述生成
多对一序列单个序列分类情感分析、视频分类
多对多 (同步)序列序列逐帧标注视频帧分类、POS 标注
多对多 (异步)序列序列序列到序列转换机器翻译、语音识别

2. Vanilla RNN 原理与问题

2.1 基础架构

前向传播:

# 伪代码
for t in range(T):
    h[t] = tanh(W_hh @ h[t-1] + W_xh @ x[t] + b_h)
    y[t] = W_hy @ h[t] + b_y

参数维度:

  • 输入维度:
  • 隐藏维度:
  • 输出维度:
  • :
  • :
  • :

2.2 反向传播:BPTT(时间反向传播)

挑战: 梯度需要沿时间反向传播

梯度公式:

其中每个 涉及链式法则:

核心问题: 包含 个矩阵乘积!

2.3 梯度消失与梯度爆炸

2.3.1 梯度消失(Vanishing Gradient)

原因分析:

  • ,通常 < 0.25
  • 连乘导致 很大

后果:

  • 长距离依赖无法学习
  • 实际有效记忆长度 < 10 步
  • 早期时间步梯度趋近于零

示例:

句子: "The cat, which was very fluffy and cute, was sitting on the mat."
问题: "was" 的主语是 "cat"(距离 8 个词)
结果: Vanilla RNN 难以建立这种联系

2.3.2 梯度爆炸(Exploding Gradient)

原因: 如果 的最大特征值 > 1,梯度指数增长

后果:

  • 参数更新过大
  • 训练不稳定,loss 震荡
  • 数值溢出 (NaN)

解决方案:

  1. 梯度裁剪(Gradient Clipping):

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
  2. 权重正则化:

    • 限制 的谱范数

2.4 其他问题

问题描述影响
缓慢训练序列计算无法并行化训练时间长
长序列困难隐状态容量有限信息瓶颈
短期记忆偏向近期信息主导长距离依赖丢失
激活函数饱和Tanh 在极值区梯度接近 0加剧梯度消失

3. LSTM:长短期记忆网络

3.1 核心创新

作者: Sepp Hochreiter & Jürgen Schmidhuber (1997)
论文: “Long Short-Term Memory”

关键思想: 使用门控机制选择性地记忆和遗忘信息

3.2 LSTM 架构详解

3.2.1 组件概览

输入: x_t, h_{t-1}, c_{t-1}
输出: h_t, c_t

c_{t-1} ───────────────────→ c_t (细胞状态,长期记忆)
           ×         +
          [遗忘门]  [输入门]
           ↑         ↑
         x_t, h_{t-1}

h_{t-1} ────→ [输出门] ────→ h_t (隐藏状态,短期记忆)
                ↑
              tanh(c_t)

3.2.2 四个核心门

1. 遗忘门(Forget Gate): 决定丢弃多少旧记忆

  • 输出:
  • 0 = 完全遗忘,1 = 完全保留

2. 输入门(Input Gate): 决定新信息写入多少

  • :新信息的权重
  • :候选记忆(新信息内容)

3. 细胞状态更新: 融合旧记忆和新信息

  • :逐元素乘法(Hadamard 积)
  • :保留的旧记忆
  • :添加的新记忆

4. 输出门(Output Gate): 决定输出多少信息

3.2.3 完整前向传播

# 完整 LSTM 单元伪代码
def lstm_cell(x_t, h_prev, C_prev, W_f, W_i, W_C, W_o):
    # 拼接输入
    concat = torch.cat([h_prev, x_t], dim=-1)
    
    # 遗忘门
    f_t = torch.sigmoid(W_f @ concat + b_f)
    
    # 输入门
    i_t = torch.sigmoid(W_i @ concat + b_i)
    C_tilde = torch.tanh(W_C @ concat + b_C)
    
    # 更新细胞状态
    C_t = f_t * C_prev + i_t * C_tilde
    
    # 输出门
    o_t = torch.sigmoid(W_o @ concat + b_o)
    h_t = o_t * torch.tanh(C_t)
    
    return h_t, C_t

3.3 为什么 LSTM 有效?

3.3.1 缓解梯度消失

关键: 细胞状态 的更新是加法,非乘法

  • 梯度直接通过 (接近 1)传递
  • 避免长距离连乘导致的消失
  • 称为”梯度高速公路”(Gradient Highway)

对比:

  • Vanilla RNN: (连乘)
  • LSTM: (加法路径)

3.3.2 选择性记忆

门的作用:

  • 遗忘门 = 低: 丢弃无关信息(如填充词 “the”, “a”)
  • 输入门 = 高: 记住关键信息(如人名、数字)
  • 输出门: 控制短期输出,不影响长期记忆

示例:情感分析

输入: "The movie was not very good but the acting was great."
遗忘门: 低权重给 "not very good"(否定)
输入门: 高权重给 "acting was great"(正面)
输出: 正面情感(输出门强调后半句)

3.4 LSTM 变体

3.4.1 Peephole Connections (2000)

创新: 让门可以”窥视”细胞状态

效果: 提升时序精确任务(如计数)

3.4.2 Coupled Forget-Input Gates

简化: (遗忘多少就记住多少)

优势: 减少参数,加速训练


4. GRU:门控循环单元

4.1 核心思想

作者: Kyunghyun Cho et al. (2014)
论文: “Learning Phrase Representations using RNN Encoder-Decoder”

设计目标: 简化 LSTM,性能相当但参数更少

4.2 GRU 架构

4.2.1 结构对比

组件LSTMGRU
门数量3 个(f, i, o)2 个(r, z)
状态变量2 个(h, C)1 个(h)
参数量4 组权重3 组权重

4.2.2 公式详解

1. 重置门(Reset Gate): 决定遗忘多少历史

2. 更新门(Update Gate): 决定新旧信息比例

3. 候选隐状态: 新信息内容

4. 最终隐状态: 融合新旧信息

4.2.3 直观理解

z_t = 0.8  (更新门高)
├─ 80% 新信息 (h_tilde)
└─ 20% 旧信息 (h_{t-1})
→ 倾向于记住新信息

z_t = 0.2  (更新门低)
├─ 20% 新信息
└─ 80% 旧信息
→ 倾向于保留历史

4.3 GRU vs LSTM

维度LSTMGRU胜者
参数量GRU
训练速度快 20-30%GRU
长距离依赖强(专用 C)稍弱LSTM
小数据集易过拟合泛化更好GRU
大数据集性能更优持平LSTM
可解释性复杂(4 个门)简单(2 个门)GRU

实践建议:

  • 优先尝试 GRU: 训练快,参数少,多数任务性能相当
  • 长序列或大数据: 考虑 LSTM
  • 受限硬件: GRU 更轻量

4.4 门控机制的本质

共同点:所有门控 RNN 都在解决:

  1. 梯度流动: 提供”高速公路”(加法路径)
  2. 选择性记忆: 动态决定记忆/遗忘
  3. 自适应: 不同时间步不同策略

门的数学意义:

  • 软注意力机制(Soft Attention)
  • 可微分的 if-else 逻辑
  • 0-1 之间的加权插值

5. RNN 变体与优化

5.1 双向 RNN(Bidirectional RNN)

动机: 许多任务需要”未来”信息(如词性标注)

架构:

前向 RNN: h_t^{→} = f(h_{t-1}^{→}, x_t)
后向 RNN: h_t^{←} = f(h_{t+1}^{←}, x_t)
输出:     h_t = [h_t^{→}; h_t^{←}]  # 拼接

应用:

  • 命名实体识别(NER)
  • 词性标注(POS Tagging)
  • 问答系统(需要全句理解)

限制: 无法用于实时生成(需要完整序列)

5.2 深层 RNN(Stacked/Deep RNN)

多层堆叠:

层 3: h_t^{(3)} = RNN(h_t^{(2)}, h_{t-1}^{(3)})
层 2: h_t^{(2)} = RNN(h_t^{(1)}, h_{t-1}^{(2)})
层 1: h_t^{(1)} = RNN(x_t, h_{t-1}^{(1)})

经验:

  • 2-4 层最常见
  • 过深容易过拟合(需 Dropout)
  • NLP:2 层通常足够
  • 语音识别:5-7 层

5.3 注意力机制 + RNN

Seq2Seq + Attention (2015):

  • 编码器:双向 LSTM
  • 注意力:动态加权源序列
  • 解码器:单向 LSTM

Bahdanau Attention:

影响: 成为 Transformer 的前身

5.4 其他变体

变体创新应用
IndRNN独立循环(避免梯度消失)超长序列
QRNN并行化卷积 + 循环加速推理
SRU简化循环单元(比 GRU 更简单)快速训练
Clockwork RNN多尺度时钟(不同更新频率)多时间尺度建模

6. 应用场景与实践

6.1 经典应用

6.1.1 语言建模(Language Modeling)

任务: 预测下一个词

架构:

embedding → LSTM(2 层) → Linear → Softmax

指标: 困惑度(Perplexity)=

2024 状态: 基本被 Transformer 取代

6.1.2 机器翻译(Seq2Seq)

Encoder-Decoder 架构:

英文输入 → Encoder LSTM → 上下文向量 → Decoder LSTM → 法文输出

2024 状态: Transformer 完全主导(BLEU 高 10+ 分)

6.1.3 语音识别(ASR)

输入: MFCC 特征序列
输出: 文字转录

架构: 双向 LSTM + CTC Loss(Connectionist Temporal Classification)

2024 状态:

  • 仍在使用(Whisper 等模型内部有 LSTM)
  • 混合架构:CNN(声学特征)+ LSTM(时序)+ Transformer(语言模型)

6.1.4 时间序列预测

应用:

  • 股票价格预测
  • 天气预报
  • 能源需求预测

2024 状态: LSTM 仍是主流(简单有效)

替代方案:

  • Temporal Convolutional Networks (TCN): 并行化训练
  • Transformer + 位置编码: 长距离依赖
  • N-BEATS: 纯 MLP 架构

6.2 训练技巧

6.2.1 梯度裁剪(必用)

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

6.2.2 序列打包(Packing)

问题: 变长序列批处理需要填充(浪费计算)

解决:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
# 打包(跳过填充位置)
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
 
# 解包
output, _ = pad_packed_sequence(output, batch_first=True)

6.2.3 初始化策略

推荐:

  • Xavier/Glorot: 标准初始化
  • 正交初始化: 使用正交矩阵(缓解梯度问题)
for name, param in lstm.named_parameters():
    if 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)

6.2.4 学习率策略

经验:

  • 初始学习率:1e-3 到 1e-2
  • 使用学习率衰减(每 N epoch 减半)
  • 或使用 ReduceLROnPlateau

6.2.5 Dropout 位置

推荐:

nn.LSTM(..., dropout=0.3)  # 层间 Dropout(多层 RNN)
nn.Dropout(0.5)            # 输出层之前

注意: 不要在时间步之间使用 Dropout(破坏时序)

6.3 常见错误与调试

问题可能原因解决方案
Loss = NaN梯度爆炸梯度裁剪 + 降低学习率
验证 loss 不下降过拟合Dropout + 正则化
训练极慢序列过长 / 批次太小截断序列 / 增大批次
预测全是同一个值陷入局部最优重新初始化 / 调整架构
显存溢出序列太长截断 BPTT / 使用 FP16
长距离依赖学不到梯度消失换 LSTM/GRU / 加注意力

7. RNN vs Transformer

7.1 核心差异

维度RNN/LSTM/GRUTransformer
并行化❌ 序列计算✅ 全并行
训练速度慢(10-100×)
长距离依赖困难(路径长 O(n))容易(直接连接 O(1))
内存需求低(O(n))高(O(n²))
推理速度快(O(1) 每步)慢(自回归 O(n))
位置信息天然(顺序处理)需显式编码
可解释性门可视化注意力图
小数据表现好(归纳偏置强)差(需大量数据)
超长序列困难(信息瓶颈)困难(计算复杂度)

7.2 性能对比

机器翻译(WMT’14 英德):

模型BLEU训练时间
LSTM Seq2Seq25.214 天
LSTM + Attention28.410 天
Transformer Base27.33 天
Transformer Big28.45 天

语言建模(Penn Treebank):

模型困惑度参数量
LSTM (3 层)78.424M
LSTM + AWD (正则化)57.324M
Transformer (12 层)56.447M

7.3 为什么 Transformer 胜出?

1. 并行化:

  • RNN:必须等 计算完才能算
  • Transformer:所有位置同时计算

2. 长距离依赖:

  • RNN:信息经过 步传播,路径长
  • Transformer:任意位置直接连接,路径长度 = 1

3. 表达能力:

  • RNN:固定容量的隐状态瓶颈
  • Transformer:注意力动态选择相关信息

4. 工程优化:

  • Transformer 充分利用 GPU 并行(矩阵乘法)
  • RNN 的顺序计算难以加速

7.4 RNN 的剩余优势

仍然适合 RNN 的场景:

  1. 实时流式处理

    • 在线语音识别(边说边转录)
    • 实时传感器数据(IoT)
    • 逐帧视频分析
  2. 极长序列(百万步)

    • DNA 序列分析
    • 长期时间序列(年级别)
    • Transformer 的 O(n²) 复杂度无法承受
  3. 内存受限环境

    • 边缘设备(手机、IoT)
    • LSTM 只需 O(n) 内存
  4. 小数据集

    • RNN 的归纳偏置(顺序性)有助于泛化
    • Transformer 需大量数据

8. 2024-2025:RNN 的复兴?

8.1 State Space Models (SSMs) 🔥

8.1.1 核心思想

问题: RNN 太慢,Transformer 太贵

解决: 用线性状态空间模型替代非线性 RNN

连续时间 SSM:

离散化后:

8.1.2 代表模型

S4 (Structured State Spaces, 2022):

  • 使用特殊结构的 矩阵(HiPPO 初始化)
  • 并行训练(卷积形式),递归推理

性能:

  • Long Range Arena:超越 Transformer
  • 序列长度 16K+:可行

Mamba (2023): 🔥

  • 选择性 SSM(输入依赖的参数)
  • 线性复杂度 O(n)
  • 在语言建模上逼近 Transformer

8.1.3 优势

对比项TransformerMamba/SSM
训练复杂度O(n²)O(n)
推理复杂度O(n)O(1)
长序列能力受限
参数效率

8.2 RWKV (2023) 🔥

全称: Receptance Weighted Key Value

创新: RNN 形式 + Transformer 性能

架构:

  • 训练:并行化(类似 Transformer)
  • 推理:递归(O(1) 每步)

性能:

  • RWKV-7B:接近 GPT-3 级别
  • 推理速度:比 Transformer 快 10×

8.3 RetNet (2023) 🔥

全称: Retentive Network

创新: 三种等价形式

  1. 并行: 训练时快速
  2. 递归: 推理时 O(1)
  3. 块递归: 平衡两者

性能:

  • 训练速度:接近 Transformer
  • 推理速度:接近 RNN
  • 性能:与 Transformer 持平

8.4 为什么 RNN 复兴?

驱动因素:

  1. 长上下文需求: GPT-4 需要 128K tokens,Transformer O(n²) 不可承受
  2. 推理成本: 大模型推理成本高昂,RNN 的 O(1) 推理吸引人
  3. 理论突破: SSM 提供了数学上优雅的替代方案

2025 趋势:

  • 混合架构: Transformer(局部)+ SSM(全局)
  • 特定领域: 长序列任务(DNA、音频)重回 RNN 风格
  • 边缘部署: 内存高效的 LSTM/GRU 依然重要

9. 代码实现

9.1 从零实现 Vanilla RNN

import torch
import torch.nn as nn
 
class VanillaRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 权重矩阵
        self.W_xh = nn.Linear(input_size, hidden_size)
        self.W_hh = nn.Linear(hidden_size, hidden_size)
        self.W_hy = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, h_prev=None):
        """
        x: (batch, seq_len, input_size)
        h_prev: (batch, hidden_size) 或 None
        返回: output (batch, seq_len, output_size), h_n (batch, hidden_size)
        """
        batch_size, seq_len, _ = x.size()
        
        # 初始化隐状态
        if h_prev is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h = h_prev
        
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch, input_size)
            h = torch.tanh(self.W_xh(x_t) + self.W_hh(h))
            y_t = self.W_hy(h)
            outputs.append(y_t.unsqueeze(1))
        
        output = torch.cat(outputs, dim=1)  # (batch, seq_len, output_size)
        return output, h
 
# 测试
rnn = VanillaRNN(input_size=10, hidden_size=20, output_size=5)
x = torch.randn(2, 15, 10)  # (batch=2, seq_len=15, input_size=10)
output, h_n = rnn(x)
print(f"Output shape: {output.shape}")  # (2, 15, 5)
print(f"Final hidden state: {h_n.shape}")  # (2, 20)

9.2 从零实现 LSTM

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 4 个门的权重(合并计算)
        self.W_ih = nn.Linear(input_size, 4 * hidden_size)
        self.W_hh = nn.Linear(hidden_size, 4 * hidden_size)
        
    def forward(self, x, states=None):
        """
        x: (batch, input_size)
        states: (h_prev, C_prev) 或 None
        """
        batch_size = x.size(0)
        
        if states is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
            C_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_prev, C_prev = states
        
        # 一次性计算 4 个门
        gates = self.W_ih(x) + self.W_hh(h_prev)  # (batch, 4*hidden_size)
        
        # 分割成 4 个门
        i, f, g, o = gates.chunk(4, dim=1)
        
        # 应用激活函数
        i = torch.sigmoid(i)  # 输入门
        f = torch.sigmoid(f)  # 遗忘门
        g = torch.tanh(g)     # 候选记忆
        o = torch.sigmoid(o)  # 输出门
        
        # 更新细胞状态和隐状态
        C = f * C_prev + i * g
        h = o * torch.tanh(C)
        
        return h, (h, C)
 
 
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.cells = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
        
    def forward(self, x, states=None):
        """
        x: (batch, seq_len, input_size)
        states: [(h, C) for each layer] 或 None
        """
        batch_size, seq_len, _ = x.size()
        
        if states is None:
            states = [None] * self.num_layers
        
        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            
            # 逐层处理
            for layer, cell in enumerate(self.cells):
                x_t, states[layer] = cell(x_t, states[layer])
            
            outputs.append(x_t.unsqueeze(1))
        
        output = torch.cat(outputs, dim=1)
        h_n = torch.stack([s[0] for s in states])  # (num_layers, batch, hidden_size)
        C_n = torch.stack([s[1] for s in states])
        
        return output, (h_n, C_n)
 
# 测试
lstm = LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(2, 15, 10)
output, (h_n, C_n) = lstm(x)
print(f"Output: {output.shape}")    # (2, 15, 20)
print(f"h_n: {h_n.shape}")          # (2, 2, 20)
print(f"C_n: {C_n.shape}")          # (2, 2, 20)

9.3 使用 PyTorch 内置 LSTM

import torch
import torch.nn as nn
 
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2, dropout=0.5):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True,
            bidirectional=True  # 双向 LSTM
        )
        
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # *2 因为双向
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text, text_lengths):
        """
        text: (batch, seq_len)
        text_lengths: (batch,) 每个样本的实际长度
        """
        # 嵌入
        embedded = self.dropout(self.embedding(text))  # (batch, seq_len, emb_dim)
        
        # 打包序列(跳过填充)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        
        # LSTM
        packed_output, (h_n, c_n) = self.lstm(packed)
        
        # 解包
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        
        # 使用最后的隐状态(双向拼接)
        h_fwd = h_n[-2]  # 前向最后一层
        h_bwd = h_n[-1]  # 后向最后一层
        hidden = torch.cat([h_fwd, h_bwd], dim=1)  # (batch, hidden_dim*2)
        
        # 分类
        output = self.fc(self.dropout(hidden))  # (batch, output_dim)
        
        return output
 
# 实例化
model = TextClassifier(
    vocab_size=10000,
    embedding_dim=100,
    hidden_dim=256,
    output_dim=2,  # 二分类
    n_layers=2,
    dropout=0.5
)
 
# 示例输入
text = torch.randint(0, 10000, (4, 20))  # (batch=4, seq_len=20)
lengths = torch.tensor([20, 15, 10, 18])  # 实际长度
output = model(text, lengths)
print(f"Predictions: {output.shape}")  # (4, 2)

9.4 Seq2Seq 机器翻译

class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (h_n, c_n) = self.rnn(embedded)
        return h_n, c_n
 
 
class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, trg, hidden, cell):
        embedded = self.dropout(self.embedding(trg))
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        prediction = self.fc(output)
        return prediction, hidden, cell
 
 
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """
        src: (batch, src_len)
        trg: (batch, trg_len)
        """
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        
        # 存储输出
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=src.device)
        
        # 编码
        hidden, cell = self.encoder(src)
        
        # 解码(第一步是 <sos> token)
        input_token = trg[:, 0].unsqueeze(1)
        
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = output.squeeze(1)
            
            # Teacher forcing
            use_teacher = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(2)
            input_token = trg[:, t].unsqueeze(1) if use_teacher else top1
        
        return outputs
 
# 实例化
encoder = Encoder(vocab_size=5000, emb_dim=256, hidden_dim=512, n_layers=2, dropout=0.5)
decoder = Decoder(vocab_size=5000, emb_dim=256, hidden_dim=512, n_layers=2, dropout=0.5)
model = Seq2Seq(encoder, decoder)
 
# 示例
src = torch.randint(0, 5000, (4, 10))  # (batch, src_len)
trg = torch.randint(0, 5000, (4, 12))  # (batch, trg_len)
output = model(src, trg)
print(f"Output: {output.shape}")  # (4, 12, 5000)

9.5 训练循环示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
 
def train_epoch(model, dataloader, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    
    for batch in dataloader:
        src, trg = batch
        src, trg = src.cuda(), trg.cuda()
        
        optimizer.zero_grad()
        
        # 前向传播
        output = model(src, trg)
        
        # 计算损失(忽略 <pad> token)
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)
        
        loss = criterion(output, trg)
        
        # 反向传播
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader)
 
 
def evaluate(model, dataloader, criterion):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for batch in dataloader:
            src, trg = batch
            src, trg = src.cuda(), trg.cuda()
            
            output = model(src, trg, teacher_forcing_ratio=0)  # 验证时不用 teacher forcing
            
            output_dim = output.shape[-1]
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)
            
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader)
 
 
# 主训练循环
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # 忽略填充
 
N_EPOCHS = 10
CLIP = 1.0
 
for epoch in range(N_EPOCHS):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_loader, criterion)
    
    print(f"Epoch {epoch+1}")
    print(f"  Train Loss: {train_loss:.3f} | Train PPL: {torch.exp(torch.tensor(train_loss)):.3f}")
    print(f"  Valid Loss: {valid_loss:.3f} | Valid PPL: {torch.exp(torch.tensor(valid_loss)):.3f}")

10. 总结

10.1 核心要点回顾

RNN 家族演进:

Vanilla RNN (梯度问题) 
    ↓
LSTM (门控机制解决梯度消失)
    ↓
GRU (简化 LSTM)
    ↓
Attention RNN (增强长距离依赖)
    ↓
Transformer (完全取代 RNN)
    ↓
SSM/Mamba (RNN 复兴,线性复杂度)

关键概念:

  1. 序列建模: 处理时间依赖的自然选择
  2. 梯度问题: RNN 的核心挑战
  3. 门控机制: LSTM/GRU 的解决方案
  4. 并行化限制: 导致 Transformer 崛起
  5. 复兴趋势: SSM 提供新可能

10.2 实践建议

2025 年如何选择:

场景推荐方案理由
新 NLP 项目Transformer性能最优,生态成熟
实时流式处理LSTM/GRUO(1) 推理,无需完整序列
时间序列预测LSTM + TCN简单有效,可解释性强
超长序列(100K+)Mamba / RetNet线性复杂度
小数据集LSTM + 数据增强归纳偏置强
边缘设备GRU 量化版本内存高效
科研探索SSM / RWKV前沿方向

10.3 学习路径

初学者:

  1. 理解 Vanilla RNN(掌握循环思想)
  2. 学习 LSTM(门控机制)
  3. 实现简单项目(情感分析、时间序列)

进阶:

  1. 深入 BPTT 和梯度问题
  2. Seq2Seq + Attention
  3. 对比 RNN 与 Transformer

前沿:

  1. State Space Models (S4, Mamba)
  2. 混合架构设计
  3. 长上下文建模

11. 扩展阅读

11.1 必读论文

经典:

  • LSTM: “Long Short-Term Memory” (Hochreiter, 1997)
  • GRU: “Learning Phrase Representations using RNN Encoder-Decoder” (Cho, 2014)
  • Seq2Seq: “Sequence to Sequence Learning with Neural Networks” (Sutskever, 2014)
  • Attention: “Neural Machine Translation by Jointly Learning to Align and Translate” (Bahdanau, 2015)

现代:

  • S4: “Efficiently Modeling Long Sequences with Structured State Spaces” (Gu, 2022)
  • Mamba: “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (Gu, 2023) 🔥
  • RWKV: “RWKV: Reinventing RNNs for the Transformer Era” (Peng, 2023) 🔥
  • RetNet: “Retentive Network: A Successor to Transformer for Large Language Models” (Sun, 2023) 🔥

11.2 开源资源

教程:

代码库:

前沿模型:


12. 相关笔记


最后更新: 2025-01-07
维护者: sean2077