循环神经网络(RNN/LSTM/GRU)
核心定位: 处理序列数据的经典架构,从 1986 年 RNN 到 2023 年被 Transformer 大规模取代的演进史。 关联笔记: 神经网络、神经网络层类型详解、Transformer、卷积神经网络
1. 为什么需要 RNN
1.1 序列数据的挑战
序列数据无处不在:
- 文本:单词序列
- 语音:音频帧序列
- 视频:图像帧序列
- 时间序列:股票价格、气温变化
传统神经网络的局限:
- 固定输入长度: 无法处理可变长度序列
- 无记忆能力: 每个输入独立处理
- 参数爆炸: 长序列需要巨量参数
1.2 RNN 的核心思想
循环连接: 隐状态在时间步之间传递
时间步 t-1 时间步 t 时间步 t+1
↓ ↓ ↓
x_{t-1} x_t x_{t+1}
↓ ↓ ↓
[RNN] ──h_{t-1}→ [RNN] ──h_t→ [RNN]
↓ ↓ ↓
y_{t-1} y_t y_{t+1}
数学表达:
关键特性:
- 参数共享: 所有时间步共享权重
- 记忆机制: 隐状态 编码历史信息
- 可变长度: 理论上可处理任意长度序列
1.3 RNN 的输入输出模式
| 模式 | 输入 | 输出 | 应用场景 | 示例 |
|---|---|---|---|---|
| 一对一 | 单个 | 单个 | 传统任务 | 图像分类 |
| 一对多 | 单个 | 序列 | 序列生成 | 图像描述生成 |
| 多对一 | 序列 | 单个 | 序列分类 | 情感分析、视频分类 |
| 多对多 (同步) | 序列 | 序列 | 逐帧标注 | 视频帧分类、POS 标注 |
| 多对多 (异步) | 序列 | 序列 | 序列到序列转换 | 机器翻译、语音识别 |
2. Vanilla RNN 原理与问题
2.1 基础架构
前向传播:
# 伪代码
for t in range(T):
h[t] = tanh(W_hh @ h[t-1] + W_xh @ x[t] + b_h)
y[t] = W_hy @ h[t] + b_y参数维度:
- 输入维度:
- 隐藏维度:
- 输出维度:
- :
- :
- :
2.2 反向传播:BPTT(时间反向传播)
挑战: 梯度需要沿时间反向传播
梯度公式:
其中每个 涉及链式法则:
核心问题: 包含 个矩阵乘积!
2.3 梯度消失与梯度爆炸
2.3.1 梯度消失(Vanishing Gradient)
原因分析:
- ,通常 < 0.25
- 连乘导致 当 很大
后果:
- 长距离依赖无法学习
- 实际有效记忆长度 < 10 步
- 早期时间步梯度趋近于零
示例:
句子: "The cat, which was very fluffy and cute, was sitting on the mat."
问题: "was" 的主语是 "cat"(距离 8 个词)
结果: Vanilla RNN 难以建立这种联系
2.3.2 梯度爆炸(Exploding Gradient)
原因: 如果 的最大特征值 > 1,梯度指数增长
后果:
- 参数更新过大
- 训练不稳定,loss 震荡
- 数值溢出 (NaN)
解决方案:
-
梯度裁剪(Gradient Clipping):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) -
权重正则化:
- 限制 的谱范数
2.4 其他问题
| 问题 | 描述 | 影响 |
|---|---|---|
| 缓慢训练 | 序列计算无法并行化 | 训练时间长 |
| 长序列困难 | 隐状态容量有限 | 信息瓶颈 |
| 短期记忆偏向 | 近期信息主导 | 长距离依赖丢失 |
| 激活函数饱和 | Tanh 在极值区梯度接近 0 | 加剧梯度消失 |
3. LSTM:长短期记忆网络
3.1 核心创新
作者: Sepp Hochreiter & Jürgen Schmidhuber (1997)
论文: “Long Short-Term Memory”
关键思想: 使用门控机制选择性地记忆和遗忘信息
3.2 LSTM 架构详解
3.2.1 组件概览
输入: x_t, h_{t-1}, c_{t-1}
输出: h_t, c_t
c_{t-1} ───────────────────→ c_t (细胞状态,长期记忆)
× +
[遗忘门] [输入门]
↑ ↑
x_t, h_{t-1}
h_{t-1} ────→ [输出门] ────→ h_t (隐藏状态,短期记忆)
↑
tanh(c_t)
3.2.2 四个核心门
1. 遗忘门(Forget Gate): 决定丢弃多少旧记忆
- 输出:
- 0 = 完全遗忘,1 = 完全保留
2. 输入门(Input Gate): 决定新信息写入多少
- :新信息的权重
- :候选记忆(新信息内容)
3. 细胞状态更新: 融合旧记忆和新信息
- :逐元素乘法(Hadamard 积)
- :保留的旧记忆
- :添加的新记忆
4. 输出门(Output Gate): 决定输出多少信息
3.2.3 完整前向传播
# 完整 LSTM 单元伪代码
def lstm_cell(x_t, h_prev, C_prev, W_f, W_i, W_C, W_o):
# 拼接输入
concat = torch.cat([h_prev, x_t], dim=-1)
# 遗忘门
f_t = torch.sigmoid(W_f @ concat + b_f)
# 输入门
i_t = torch.sigmoid(W_i @ concat + b_i)
C_tilde = torch.tanh(W_C @ concat + b_C)
# 更新细胞状态
C_t = f_t * C_prev + i_t * C_tilde
# 输出门
o_t = torch.sigmoid(W_o @ concat + b_o)
h_t = o_t * torch.tanh(C_t)
return h_t, C_t3.3 为什么 LSTM 有效?
3.3.1 缓解梯度消失
关键: 细胞状态 的更新是加法,非乘法
- 梯度直接通过 (接近 1)传递
- 避免长距离连乘导致的消失
- 称为”梯度高速公路”(Gradient Highway)
对比:
- Vanilla RNN: (连乘)
- LSTM: (加法路径)
3.3.2 选择性记忆
门的作用:
- 遗忘门 = 低: 丢弃无关信息(如填充词 “the”, “a”)
- 输入门 = 高: 记住关键信息(如人名、数字)
- 输出门: 控制短期输出,不影响长期记忆
示例:情感分析
输入: "The movie was not very good but the acting was great."
遗忘门: 低权重给 "not very good"(否定)
输入门: 高权重给 "acting was great"(正面)
输出: 正面情感(输出门强调后半句)
3.4 LSTM 变体
3.4.1 Peephole Connections (2000)
创新: 让门可以”窥视”细胞状态
效果: 提升时序精确任务(如计数)
3.4.2 Coupled Forget-Input Gates
简化: (遗忘多少就记住多少)
优势: 减少参数,加速训练
4. GRU:门控循环单元
4.1 核心思想
作者: Kyunghyun Cho et al. (2014)
论文: “Learning Phrase Representations using RNN Encoder-Decoder”
设计目标: 简化 LSTM,性能相当但参数更少
4.2 GRU 架构
4.2.1 结构对比
| 组件 | LSTM | GRU |
|---|---|---|
| 门数量 | 3 个(f, i, o) | 2 个(r, z) |
| 状态变量 | 2 个(h, C) | 1 个(h) |
| 参数量 | 4 组权重 | 3 组权重 |
4.2.2 公式详解
1. 重置门(Reset Gate): 决定遗忘多少历史
2. 更新门(Update Gate): 决定新旧信息比例
3. 候选隐状态: 新信息内容
4. 最终隐状态: 融合新旧信息
4.2.3 直观理解
z_t = 0.8 (更新门高)
├─ 80% 新信息 (h_tilde)
└─ 20% 旧信息 (h_{t-1})
→ 倾向于记住新信息
z_t = 0.2 (更新门低)
├─ 20% 新信息
└─ 80% 旧信息
→ 倾向于保留历史
4.3 GRU vs LSTM
| 维度 | LSTM | GRU | 胜者 |
|---|---|---|---|
| 参数量 | GRU | ||
| 训练速度 | 慢 | 快 20-30% | GRU |
| 长距离依赖 | 强(专用 C) | 稍弱 | LSTM |
| 小数据集 | 易过拟合 | 泛化更好 | GRU |
| 大数据集 | 性能更优 | 持平 | LSTM |
| 可解释性 | 复杂(4 个门) | 简单(2 个门) | GRU |
实践建议:
- 优先尝试 GRU: 训练快,参数少,多数任务性能相当
- 长序列或大数据: 考虑 LSTM
- 受限硬件: GRU 更轻量
4.4 门控机制的本质
共同点:所有门控 RNN 都在解决:
- 梯度流动: 提供”高速公路”(加法路径)
- 选择性记忆: 动态决定记忆/遗忘
- 自适应: 不同时间步不同策略
门的数学意义:
- 软注意力机制(Soft Attention)
- 可微分的 if-else 逻辑
- 0-1 之间的加权插值
5. RNN 变体与优化
5.1 双向 RNN(Bidirectional RNN)
动机: 许多任务需要”未来”信息(如词性标注)
架构:
前向 RNN: h_t^{→} = f(h_{t-1}^{→}, x_t)
后向 RNN: h_t^{←} = f(h_{t+1}^{←}, x_t)
输出: h_t = [h_t^{→}; h_t^{←}] # 拼接
应用:
- 命名实体识别(NER)
- 词性标注(POS Tagging)
- 问答系统(需要全句理解)
限制: 无法用于实时生成(需要完整序列)
5.2 深层 RNN(Stacked/Deep RNN)
多层堆叠:
层 3: h_t^{(3)} = RNN(h_t^{(2)}, h_{t-1}^{(3)})
层 2: h_t^{(2)} = RNN(h_t^{(1)}, h_{t-1}^{(2)})
层 1: h_t^{(1)} = RNN(x_t, h_{t-1}^{(1)})
经验:
- 2-4 层最常见
- 过深容易过拟合(需 Dropout)
- NLP:2 层通常足够
- 语音识别:5-7 层
5.3 注意力机制 + RNN
Seq2Seq + Attention (2015):
- 编码器:双向 LSTM
- 注意力:动态加权源序列
- 解码器:单向 LSTM
Bahdanau Attention:
影响: 成为 Transformer 的前身
5.4 其他变体
| 变体 | 创新 | 应用 |
|---|---|---|
| IndRNN | 独立循环(避免梯度消失) | 超长序列 |
| QRNN | 并行化卷积 + 循环 | 加速推理 |
| SRU | 简化循环单元(比 GRU 更简单) | 快速训练 |
| Clockwork RNN | 多尺度时钟(不同更新频率) | 多时间尺度建模 |
6. 应用场景与实践
6.1 经典应用
6.1.1 语言建模(Language Modeling)
任务: 预测下一个词
架构:
embedding → LSTM(2 层) → Linear → Softmax指标: 困惑度(Perplexity)=
2024 状态: 基本被 Transformer 取代
6.1.2 机器翻译(Seq2Seq)
Encoder-Decoder 架构:
英文输入 → Encoder LSTM → 上下文向量 → Decoder LSTM → 法文输出
2024 状态: Transformer 完全主导(BLEU 高 10+ 分)
6.1.3 语音识别(ASR)
输入: MFCC 特征序列
输出: 文字转录
架构: 双向 LSTM + CTC Loss(Connectionist Temporal Classification)
2024 状态:
- 仍在使用(Whisper 等模型内部有 LSTM)
- 混合架构:CNN(声学特征)+ LSTM(时序)+ Transformer(语言模型)
6.1.4 时间序列预测
应用:
- 股票价格预测
- 天气预报
- 能源需求预测
2024 状态: LSTM 仍是主流(简单有效)
替代方案:
- Temporal Convolutional Networks (TCN): 并行化训练
- Transformer + 位置编码: 长距离依赖
- N-BEATS: 纯 MLP 架构
6.2 训练技巧
6.2.1 梯度裁剪(必用)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)6.2.2 序列打包(Packing)
问题: 变长序列批处理需要填充(浪费计算)
解决:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# 打包(跳过填充位置)
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
# 解包
output, _ = pad_packed_sequence(output, batch_first=True)6.2.3 初始化策略
推荐:
- Xavier/Glorot: 标准初始化
- 正交初始化: 使用正交矩阵(缓解梯度问题)
for name, param in lstm.named_parameters():
if 'weight_hh' in name:
nn.init.orthogonal_(param)
elif 'weight_ih' in name:
nn.init.xavier_uniform_(param)
elif 'bias' in name:
nn.init.zeros_(param)6.2.4 学习率策略
经验:
- 初始学习率:1e-3 到 1e-2
- 使用学习率衰减(每 N epoch 减半)
- 或使用 ReduceLROnPlateau
6.2.5 Dropout 位置
推荐:
nn.LSTM(..., dropout=0.3) # 层间 Dropout(多层 RNN)
nn.Dropout(0.5) # 输出层之前注意: 不要在时间步之间使用 Dropout(破坏时序)
6.3 常见错误与调试
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| Loss = NaN | 梯度爆炸 | 梯度裁剪 + 降低学习率 |
| 验证 loss 不下降 | 过拟合 | Dropout + 正则化 |
| 训练极慢 | 序列过长 / 批次太小 | 截断序列 / 增大批次 |
| 预测全是同一个值 | 陷入局部最优 | 重新初始化 / 调整架构 |
| 显存溢出 | 序列太长 | 截断 BPTT / 使用 FP16 |
| 长距离依赖学不到 | 梯度消失 | 换 LSTM/GRU / 加注意力 |
7. RNN vs Transformer
7.1 核心差异
| 维度 | RNN/LSTM/GRU | Transformer |
|---|---|---|
| 并行化 | ❌ 序列计算 | ✅ 全并行 |
| 训练速度 | 慢(10-100×) | 快 |
| 长距离依赖 | 困难(路径长 O(n)) | 容易(直接连接 O(1)) |
| 内存需求 | 低(O(n)) | 高(O(n²)) |
| 推理速度 | 快(O(1) 每步) | 慢(自回归 O(n)) |
| 位置信息 | 天然(顺序处理) | 需显式编码 |
| 可解释性 | 门可视化 | 注意力图 |
| 小数据表现 | 好(归纳偏置强) | 差(需大量数据) |
| 超长序列 | 困难(信息瓶颈) | 困难(计算复杂度) |
7.2 性能对比
机器翻译(WMT’14 英德):
| 模型 | BLEU | 训练时间 |
|---|---|---|
| LSTM Seq2Seq | 25.2 | 14 天 |
| LSTM + Attention | 28.4 | 10 天 |
| Transformer Base | 27.3 | 3 天 |
| Transformer Big | 28.4 | 5 天 |
语言建模(Penn Treebank):
| 模型 | 困惑度 | 参数量 |
|---|---|---|
| LSTM (3 层) | 78.4 | 24M |
| LSTM + AWD (正则化) | 57.3 | 24M |
| Transformer (12 层) | 56.4 | 47M |
7.3 为什么 Transformer 胜出?
1. 并行化:
- RNN:必须等 计算完才能算
- Transformer:所有位置同时计算
2. 长距离依赖:
- RNN:信息经过 步传播,路径长
- Transformer:任意位置直接连接,路径长度 = 1
3. 表达能力:
- RNN:固定容量的隐状态瓶颈
- Transformer:注意力动态选择相关信息
4. 工程优化:
- Transformer 充分利用 GPU 并行(矩阵乘法)
- RNN 的顺序计算难以加速
7.4 RNN 的剩余优势
✅ 仍然适合 RNN 的场景:
-
实时流式处理
- 在线语音识别(边说边转录)
- 实时传感器数据(IoT)
- 逐帧视频分析
-
极长序列(百万步)
- DNA 序列分析
- 长期时间序列(年级别)
- Transformer 的 O(n²) 复杂度无法承受
-
内存受限环境
- 边缘设备(手机、IoT)
- LSTM 只需 O(n) 内存
-
小数据集
- RNN 的归纳偏置(顺序性)有助于泛化
- Transformer 需大量数据
8. 2024-2025:RNN 的复兴?
8.1 State Space Models (SSMs) 🔥
8.1.1 核心思想
问题: RNN 太慢,Transformer 太贵
解决: 用线性状态空间模型替代非线性 RNN
连续时间 SSM:
离散化后:
8.1.2 代表模型
S4 (Structured State Spaces, 2022):
- 使用特殊结构的 矩阵(HiPPO 初始化)
- 并行训练(卷积形式),递归推理
性能:
- Long Range Arena:超越 Transformer
- 序列长度 16K+:可行
Mamba (2023): 🔥
- 选择性 SSM(输入依赖的参数)
- 线性复杂度 O(n)
- 在语言建模上逼近 Transformer
8.1.3 优势
| 对比项 | Transformer | Mamba/SSM |
|---|---|---|
| 训练复杂度 | O(n²) | O(n) |
| 推理复杂度 | O(n) | O(1) |
| 长序列能力 | 受限 | 强 |
| 参数效率 | 低 | 高 |
8.2 RWKV (2023) 🔥
全称: Receptance Weighted Key Value
创新: RNN 形式 + Transformer 性能
架构:
- 训练:并行化(类似 Transformer)
- 推理:递归(O(1) 每步)
性能:
- RWKV-7B:接近 GPT-3 级别
- 推理速度:比 Transformer 快 10×
8.3 RetNet (2023) 🔥
全称: Retentive Network
创新: 三种等价形式
- 并行: 训练时快速
- 递归: 推理时 O(1)
- 块递归: 平衡两者
性能:
- 训练速度:接近 Transformer
- 推理速度:接近 RNN
- 性能:与 Transformer 持平
8.4 为什么 RNN 复兴?
驱动因素:
- 长上下文需求: GPT-4 需要 128K tokens,Transformer O(n²) 不可承受
- 推理成本: 大模型推理成本高昂,RNN 的 O(1) 推理吸引人
- 理论突破: SSM 提供了数学上优雅的替代方案
2025 趋势:
- 混合架构: Transformer(局部)+ SSM(全局)
- 特定领域: 长序列任务(DNA、音频)重回 RNN 风格
- 边缘部署: 内存高效的 LSTM/GRU 依然重要
9. 代码实现
9.1 从零实现 Vanilla RNN
import torch
import torch.nn as nn
class VanillaRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
# 权重矩阵
self.W_xh = nn.Linear(input_size, hidden_size)
self.W_hh = nn.Linear(hidden_size, hidden_size)
self.W_hy = nn.Linear(hidden_size, output_size)
def forward(self, x, h_prev=None):
"""
x: (batch, seq_len, input_size)
h_prev: (batch, hidden_size) 或 None
返回: output (batch, seq_len, output_size), h_n (batch, hidden_size)
"""
batch_size, seq_len, _ = x.size()
# 初始化隐状态
if h_prev is None:
h = torch.zeros(batch_size, self.hidden_size, device=x.device)
else:
h = h_prev
outputs = []
for t in range(seq_len):
x_t = x[:, t, :] # (batch, input_size)
h = torch.tanh(self.W_xh(x_t) + self.W_hh(h))
y_t = self.W_hy(h)
outputs.append(y_t.unsqueeze(1))
output = torch.cat(outputs, dim=1) # (batch, seq_len, output_size)
return output, h
# 测试
rnn = VanillaRNN(input_size=10, hidden_size=20, output_size=5)
x = torch.randn(2, 15, 10) # (batch=2, seq_len=15, input_size=10)
output, h_n = rnn(x)
print(f"Output shape: {output.shape}") # (2, 15, 5)
print(f"Final hidden state: {h_n.shape}") # (2, 20)9.2 从零实现 LSTM
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 4 个门的权重(合并计算)
self.W_ih = nn.Linear(input_size, 4 * hidden_size)
self.W_hh = nn.Linear(hidden_size, 4 * hidden_size)
def forward(self, x, states=None):
"""
x: (batch, input_size)
states: (h_prev, C_prev) 或 None
"""
batch_size = x.size(0)
if states is None:
h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
C_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
else:
h_prev, C_prev = states
# 一次性计算 4 个门
gates = self.W_ih(x) + self.W_hh(h_prev) # (batch, 4*hidden_size)
# 分割成 4 个门
i, f, g, o = gates.chunk(4, dim=1)
# 应用激活函数
i = torch.sigmoid(i) # 输入门
f = torch.sigmoid(f) # 遗忘门
g = torch.tanh(g) # 候选记忆
o = torch.sigmoid(o) # 输出门
# 更新细胞状态和隐状态
C = f * C_prev + i * g
h = o * torch.tanh(C)
return h, (h, C)
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.cells = nn.ModuleList([
LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
for i in range(num_layers)
])
def forward(self, x, states=None):
"""
x: (batch, seq_len, input_size)
states: [(h, C) for each layer] 或 None
"""
batch_size, seq_len, _ = x.size()
if states is None:
states = [None] * self.num_layers
outputs = []
for t in range(seq_len):
x_t = x[:, t, :]
# 逐层处理
for layer, cell in enumerate(self.cells):
x_t, states[layer] = cell(x_t, states[layer])
outputs.append(x_t.unsqueeze(1))
output = torch.cat(outputs, dim=1)
h_n = torch.stack([s[0] for s in states]) # (num_layers, batch, hidden_size)
C_n = torch.stack([s[1] for s in states])
return output, (h_n, C_n)
# 测试
lstm = LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(2, 15, 10)
output, (h_n, C_n) = lstm(x)
print(f"Output: {output.shape}") # (2, 15, 20)
print(f"h_n: {h_n.shape}") # (2, 2, 20)
print(f"C_n: {C_n.shape}") # (2, 2, 20)9.3 使用 PyTorch 内置 LSTM
import torch
import torch.nn as nn
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2, dropout=0.5):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=n_layers,
dropout=dropout if n_layers > 1 else 0,
batch_first=True,
bidirectional=True # 双向 LSTM
)
self.fc = nn.Linear(hidden_dim * 2, output_dim) # *2 因为双向
self.dropout = nn.Dropout(dropout)
def forward(self, text, text_lengths):
"""
text: (batch, seq_len)
text_lengths: (batch,) 每个样本的实际长度
"""
# 嵌入
embedded = self.dropout(self.embedding(text)) # (batch, seq_len, emb_dim)
# 打包序列(跳过填充)
packed = nn.utils.rnn.pack_padded_sequence(
embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False
)
# LSTM
packed_output, (h_n, c_n) = self.lstm(packed)
# 解包
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
# 使用最后的隐状态(双向拼接)
h_fwd = h_n[-2] # 前向最后一层
h_bwd = h_n[-1] # 后向最后一层
hidden = torch.cat([h_fwd, h_bwd], dim=1) # (batch, hidden_dim*2)
# 分类
output = self.fc(self.dropout(hidden)) # (batch, output_dim)
return output
# 实例化
model = TextClassifier(
vocab_size=10000,
embedding_dim=100,
hidden_dim=256,
output_dim=2, # 二分类
n_layers=2,
dropout=0.5
)
# 示例输入
text = torch.randint(0, 10000, (4, 20)) # (batch=4, seq_len=20)
lengths = torch.tensor([20, 15, 10, 18]) # 实际长度
output = model(text, lengths)
print(f"Predictions: {output.shape}") # (4, 2)9.4 Seq2Seq 机器翻译
class Encoder(nn.Module):
def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim)
self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
embedded = self.dropout(self.embedding(src))
outputs, (h_n, c_n) = self.rnn(embedded)
return h_n, c_n
class Decoder(nn.Module):
def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim)
self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
self.fc = nn.Linear(hidden_dim, vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, trg, hidden, cell):
embedded = self.dropout(self.embedding(trg))
output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
prediction = self.fc(output)
return prediction, hidden, cell
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, src, trg, teacher_forcing_ratio=0.5):
"""
src: (batch, src_len)
trg: (batch, trg_len)
"""
batch_size = src.size(0)
trg_len = trg.size(1)
trg_vocab_size = self.decoder.fc.out_features
# 存储输出
outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=src.device)
# 编码
hidden, cell = self.encoder(src)
# 解码(第一步是 <sos> token)
input_token = trg[:, 0].unsqueeze(1)
for t in range(1, trg_len):
output, hidden, cell = self.decoder(input_token, hidden, cell)
outputs[:, t] = output.squeeze(1)
# Teacher forcing
use_teacher = torch.rand(1).item() < teacher_forcing_ratio
top1 = output.argmax(2)
input_token = trg[:, t].unsqueeze(1) if use_teacher else top1
return outputs
# 实例化
encoder = Encoder(vocab_size=5000, emb_dim=256, hidden_dim=512, n_layers=2, dropout=0.5)
decoder = Decoder(vocab_size=5000, emb_dim=256, hidden_dim=512, n_layers=2, dropout=0.5)
model = Seq2Seq(encoder, decoder)
# 示例
src = torch.randint(0, 5000, (4, 10)) # (batch, src_len)
trg = torch.randint(0, 5000, (4, 12)) # (batch, trg_len)
output = model(src, trg)
print(f"Output: {output.shape}") # (4, 12, 5000)9.5 训练循环示例
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
def train_epoch(model, dataloader, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for batch in dataloader:
src, trg = batch
src, trg = src.cuda(), trg.cuda()
optimizer.zero_grad()
# 前向传播
output = model(src, trg)
# 计算损失(忽略 <pad> token)
output_dim = output.shape[-1]
output = output[:, 1:].reshape(-1, output_dim)
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
# 反向传播
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(dataloader)
def evaluate(model, dataloader, criterion):
model.eval()
epoch_loss = 0
with torch.no_grad():
for batch in dataloader:
src, trg = batch
src, trg = src.cuda(), trg.cuda()
output = model(src, trg, teacher_forcing_ratio=0) # 验证时不用 teacher forcing
output_dim = output.shape[-1]
output = output[:, 1:].reshape(-1, output_dim)
trg = trg[:, 1:].reshape(-1)
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(dataloader)
# 主训练循环
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # 忽略填充
N_EPOCHS = 10
CLIP = 1.0
for epoch in range(N_EPOCHS):
train_loss = train_epoch(model, train_loader, optimizer, criterion, CLIP)
valid_loss = evaluate(model, valid_loader, criterion)
print(f"Epoch {epoch+1}")
print(f" Train Loss: {train_loss:.3f} | Train PPL: {torch.exp(torch.tensor(train_loss)):.3f}")
print(f" Valid Loss: {valid_loss:.3f} | Valid PPL: {torch.exp(torch.tensor(valid_loss)):.3f}")10. 总结
10.1 核心要点回顾
RNN 家族演进:
Vanilla RNN (梯度问题)
↓
LSTM (门控机制解决梯度消失)
↓
GRU (简化 LSTM)
↓
Attention RNN (增强长距离依赖)
↓
Transformer (完全取代 RNN)
↓
SSM/Mamba (RNN 复兴,线性复杂度)
关键概念:
- 序列建模: 处理时间依赖的自然选择
- 梯度问题: RNN 的核心挑战
- 门控机制: LSTM/GRU 的解决方案
- 并行化限制: 导致 Transformer 崛起
- 复兴趋势: SSM 提供新可能
10.2 实践建议
2025 年如何选择:
| 场景 | 推荐方案 | 理由 |
|---|---|---|
| 新 NLP 项目 | Transformer | 性能最优,生态成熟 |
| 实时流式处理 | LSTM/GRU | O(1) 推理,无需完整序列 |
| 时间序列预测 | LSTM + TCN | 简单有效,可解释性强 |
| 超长序列(100K+) | Mamba / RetNet | 线性复杂度 |
| 小数据集 | LSTM + 数据增强 | 归纳偏置强 |
| 边缘设备 | GRU 量化版本 | 内存高效 |
| 科研探索 | SSM / RWKV | 前沿方向 |
10.3 学习路径
初学者:
- 理解 Vanilla RNN(掌握循环思想)
- 学习 LSTM(门控机制)
- 实现简单项目(情感分析、时间序列)
进阶:
- 深入 BPTT 和梯度问题
- Seq2Seq + Attention
- 对比 RNN 与 Transformer
前沿:
- State Space Models (S4, Mamba)
- 混合架构设计
- 长上下文建模
11. 扩展阅读
11.1 必读论文
经典:
- LSTM: “Long Short-Term Memory” (Hochreiter, 1997)
- GRU: “Learning Phrase Representations using RNN Encoder-Decoder” (Cho, 2014)
- Seq2Seq: “Sequence to Sequence Learning with Neural Networks” (Sutskever, 2014)
- Attention: “Neural Machine Translation by Jointly Learning to Align and Translate” (Bahdanau, 2015)
现代:
- S4: “Efficiently Modeling Long Sequences with Structured State Spaces” (Gu, 2022)
- Mamba: “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (Gu, 2023) 🔥
- RWKV: “RWKV: Reinventing RNNs for the Transformer Era” (Peng, 2023) 🔥
- RetNet: “Retentive Network: A Successor to Transformer for Large Language Models” (Sun, 2023) 🔥
11.2 开源资源
教程:
- The Unreasonable Effectiveness of RNNs - Andrej Karpathy
- Understanding LSTM Networks - Christopher Olah
代码库:
前沿模型:
12. 相关笔记
- 神经网络 - 总览
- 神经网络层类型详解 - RNN 层详解
- Transformer - RNN 的继任者
- 卷积神经网络 - 计算机视觉架构
- 图神经网络 - 图结构数据(待创建)
- 注意力机制 - Attention 详解
- 序列到序列模型 - Seq2Seq 架构
最后更新: 2025-01-07
维护者: sean2077