Transformer 架构详解
核心定位: 完全基于注意力机制的革命性神经网络架构,现代大语言模型的基础。 原始论文: Vaswani et al. (2017) “Attention is All You Need” [NeurIPS 2017]
1 架构概览
Transformer 是 2017 年提出的纯神经网络架构,完全抛弃了传统的 RNN/CNN,仅使用注意力机制处理序列数据。
1.1 核心创新
- ✅ 并行计算: 所有位置同时处理,不像 RNN 串行递归
- ✅ 全局依赖: 注意力机制直接连接任意距离的位置
- ✅ 可扩展性: 从 100M 到 1000B+ 参数,架构基本不变
- ✅ 迁移学习: 预训练模型可在多种下游任务微调
1.2 标准架构(Encoder-Decoder)
输入 → Embedding + Positional Encoding
↓
[Encoder × N 层]
- Multi-Head Self-Attention
- Add & Norm
- Feed-Forward Network
- Add & Norm
↓
[Decoder × N 层]
- Masked Multi-Head Self-Attention
- Add & Norm
- Cross-Attention (与 Encoder 输出)
- Add & Norm
- Feed-Forward Network
- Add & Norm
↓
Linear + Softmax → 输出概率
2 核心组件详解
2.1 多头自注意力(Multi-Head Self-Attention)
数学定义:
单头注意力:
多头注意力:
其中:
参数:
- - Query 投影矩阵
- - Key 投影矩阵
- - Value 投影矩阵
- - 输出投影矩阵
典型配置:
- GPT-3: 头, ,
- LLaMA 7B: 头, ,
计算复杂度:
- 时间复杂度:( 是序列长度)
- 空间复杂度:(存储注意力矩阵)
优化技术:
- FlashAttention: IO-aware 算法,显存节省 20×
- 稀疏注意力: 仅计算部分位置(Sparse Transformer, Longformer)
- 线性注意力: 近似计算,降低到 (Performer, Linear Transformer)
2.2 位置编码(Positional Encoding)
Transformer 本身无法区分位置顺序,需要显式注入位置信息。
2.2.1 正弦位置编码(原始论文)
优点: 无需训练,可外推到更长序列 缺点: 表达能力有限
2.2.2 可学习位置嵌入(BERT)
position_embeddings = nn.Embedding(max_position, d_model)优点: 更强的表达能力 缺点: 不能外推到训练时未见过的长度
2.2.3 旋转位置编码 RoPE(LLaMA)
通过旋转矩阵注入相对位置信息:
优点:
- 编码相对位置(更符合语言特性)
- 更好的外推性能
- 当前大模型主流选择(LLaMA, Mistral, Qwen)
2.2.4 ALiBi(Attention with Linear Biases)
在注意力分数上添加线性偏置:
优点: 极简设计,外推性能优秀(BLOOM, MPT)
2.3 前馈神经网络(FFN)
标准结构:
参数:
- - 通常
变体:GLU 系列(性能更优)
SwiGLU(LLaMA):
其中 Swish
GeGLU(PaLM):
参数量对比:
- 标准 FFN:
- GLU 变体: (多 50%,但性能提升明显)
2.4 残差连接与层归一化
标准结构:
归一化位置对比:
| 类型 | 结构 | 代表模型 | 优势 |
|---|---|---|---|
| Post-LN(原始) | BERT | 训练稳定 | |
| Pre-LN | GPT-2, GPT-3 | 更易训练深层 | |
| DeepNorm | LLaMA | 超深模型稳定 |
RMSNorm(简化版 LayerNorm):
- 去掉均值中心化(只保留方差归一化)
- 计算更快,效果相当(LLaMA, T5 采用)
3 三大变体架构
3.1 Encoder-only(理解任务)
代表模型: BERT, RoBERTa, DeBERTa
结构特点:
- 双向注意力(可看到上下文)
- 适合分类、实体识别、问答
典型应用:
- 文本分类(情感分析)
- 命名实体识别(NER)
- 问答系统(SQuAD)
- 语义相似度计算
3.2 Decoder-only(生成任务)
代表模型: GPT 系列, LLaMA, Mistral, Qwen
结构特点:
- 单向注意力(因果掩码,只能看左边)
- 自回归生成(逐 token 预测)
- 当前大语言模型主流选择
核心机制:Causal Masking
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# 上三角矩阵,防止看到未来信息典型应用:
- 文本生成
- 对话系统
- 代码生成
- 指令遵循
3.3 Encoder-Decoder(序列到序列)
代表模型: T5, BART, mT5
结构特点:
- Encoder 双向 + Decoder 单向
- Cross-Attention 连接两者
- 适合输入输出长度差异大的任务
典型应用:
- 机器翻译
- 文本摘要
- 语法纠错
- 问答生成
4 训练策略
4.1 预训练任务
| 任务类型 | 目标 | 代表模型 | 公式 |
|---|---|---|---|
| Masked Language Modeling | 预测被掩码词 | BERT | |
| Causal Language Modeling | 预测下一个词 | GPT | |
| Prefix Language Modeling | 预测后缀 | UniLM | - |
| Span Corruption | 预测连续片段 | T5 | - |
| Denoising Autoencoding | 恢复原始文本 | BART | - |
4.2 优化技巧
学习率调度:
# Warmup + Cosine Decay(主流)
lr = peak_lr * min(step / warmup_steps,
0.5 * (1 + cos(π * (step - warmup) / total_steps)))梯度裁剪:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)混合精度训练:
- BF16 优先(比 FP16 更稳定,避免溢出)
- 关键参数(LayerNorm, Softmax)保持 FP32
显存优化:
- Activation Checkpointing(重计算中间激活)
- FlashAttention(IO-aware 注意力)
- ZeRO(分布式显存优化)
5 推理优化
5.1 KV Cache
原理: 缓存已计算的 Key 和 Value,避免重复计算
节省计算:
- 无 cache: 计算量
- 有 cache: 计算量(生成阶段)
显存占用:
示例(LLaMA 7B):
- 每个 token 的 KV:
- 2048 上下文: (单个请求)
5.2 Paged KV Cache(vLLM)
核心思想: 将 KV cache 分页管理(类似操作系统虚拟内存)
优势:
- 消除内存碎片
- 支持共享 prefix(多请求共用系统提示)
- 提高 GPU 利用率 ~3×
5.3 Continuous Batching
传统批处理问题:
- 批次中最慢的请求决定整体延迟
- GPU 利用率低
改进方案:
- 动态插入新请求(一旦有空位)
- 完成的请求立即移除
- 吞吐提升 ~10×
6 长上下文扩展
6.1 问题根源
标准 Transformer 的 复杂度在长序列下不可承受:
- 32K 上下文需要 1024× 于 1K 上下文的计算量
- 显存占用呈平方增长
6.2 主流解决方案
| 方法 | 原理 | 代表模型 | 上下文长度 |
|---|---|---|---|
| Sparse Attention | 仅计算部分注意力 | Longformer | 4K-16K |
| Sliding Window | 局部注意力 + 全局 token | Mistral | 32K |
| ALiBi 外推 | 线性偏置位置编码 | BLOOM, MPT | 可外推 |
| RoPE 插值 | 缩放旋转频率 | LLaMA 2 | 32K |
| YaRN | 改进 RoPE 插值 | Nous-Hermes | 64K-128K |
| FlashAttention-2 | IO 优化 | 通用 | 无上限 |
| Ring Attention | 分布式长上下文 | 研究中 | 百万级 |
7 Transformer vs 其他架构
7.1 与 RNN 对比
| 维度 | RNN/LSTM | Transformer |
|---|---|---|
| 并行性 | ❌ 串行递归 | ✅ 完全并行 |
| 长距离依赖 | ❌ 梯度消失 | ✅ 直接全局连接 |
| 训练速度 | 慢 | 快(GPU 友好) |
| 推理速度 | 快(逐步计算) | 慢(需完整前向) |
| 内存占用 | 低 | 高(KV cache) |
| 外推能力 | 好 | 一般(需特殊位置编码) |
7.2 与 CNN 对比
| 维度 | CNN | Transformer |
|---|---|---|
| 感受野 | 局部(可堆叠) | 全局(单层) |
| 参数共享 | ✅ 卷积核共享 | ❌ 每个位置独立 |
| 归纳偏置 | 强(局部性) | 弱(数据驱动) |
| 图像任务 | ✅ 传统优势 | ✅ ViT 追平/超越(大规模数据) |
| 文本任务 | ❌ 不适合 | ✅ 标准选择 |
7.3 与 State Space Models(SSM/Mamba)
| 维度 | Transformer | SSM/Mamba |
|---|---|---|
| 时间复杂度 | ||
| 空间复杂度 | ||
| 并行训练 | ✅ 完全并行 | ✅ 完全并行 |
| 推理效率 | 慢(需完整前向) | 快(递归更新) |
| 长上下文 | 需优化 | 原生支持 |
| 生态成熟度 | ✅ 极其成熟 | ⚠️ 尚在发展 |
8 典型模型配置
8.1 GPT-3(175B)
架构: Decoder-only
层数: 96
隐藏维度: 12288
注意力头数: 96
FFN 维度: 49152 (4×)
词表大小: 50257
最大序列长度: 2048
位置编码: 可学习
激活函数: GELU8.2 LLaMA 2(7B)
架构: Decoder-only
层数: 32
隐藏维度: 4096
注意力头数: 32
FFN 维度: 11008 (~2.7×)
GQA: 否(7B 使用 MHA)
词表大小: 32000
最大序列长度: 4096
位置编码: RoPE
激活函数: SiLU
归一化: RMSNorm (Pre-LN)8.3 LLaMA 2(70B)
架构: Decoder-only
层数: 80
隐藏维度: 8192
注意力头数: 64
KV 头数: 8 (GQA)
FFN 维度: 28672
词表大小: 32000
最大序列长度: 4096
位置编码: RoPE
激活函数: SiLU
归一化: RMSNorm (Pre-LN)GQA (Grouped Query Attention):
- 减少 KV cache 大小
- 保持性能的同时降低推理显存
- 70B 使用 8 个 KV 头共享给 64 个 Q 头
9 实现代码示例
9.1 最小 Transformer Block(PyTorch)
import torch
import torch.nn as nn
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# Multi-Head Attention
self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
# Feed-Forward Network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# Self-Attention + Residual + Norm
attn_out, _ = self.attention(x, x, x, attn_mask=mask)
x = self.norm1(x + attn_out)
# FFN + Residual + Norm
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x9.2 RoPE 位置编码实现
def apply_rotary_emb(q, k, cos, sin):
"""应用旋转位置编码
Args:
q, k: [batch, seq_len, n_heads, head_dim]
cos, sin: [seq_len, head_dim]
"""
# 分离偶数和奇数维度
q_rot = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).flatten(-2)
k_rot = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).flatten(-2)
# 应用旋转
q_embed = q * cos + q_rot * sin
k_embed = k * cos + k_rot * sin
return q_embed, k_embed9.3 Causal Mask 生成
def create_causal_mask(seq_len, device):
"""生成因果注意力掩码(Decoder-only)"""
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask10 性能优化 Checklist
10.1 训练阶段
- 使用 FlashAttention-2(显存 ↓20×,速度 ↑2×)
- 启用 BF16 混合精度(避免 FP16 溢出)
- Gradient Checkpointing(显存 ↓50%,速度 ↓20%)
- 梯度累积(模拟大批次)
- ZeRO Stage 2/3(分布式显存优化)
- 数据并行 + 张量并行(多 GPU)
10.2 推理阶段
- KV Cache(必选,速度 ↑10×)
- Paged Attention(vLLM,显存 ↑2-3×)
- Continuous Batching(吞吐 ↑10×)
- 量化 INT8/FP8(推理速度 ↑2×)
- Speculative Decoding(草稿模型加速)
- 批量推理(提高 GPU 利用率)
11 常见问题与调试
11.1 训练不稳定
症状: Loss 突然飙升为 NaN
原因与解决:
- FP16 溢出 → 改用 BF16
- 学习率过大 → 降低 peak_lr 或增加 warmup
- 梯度爆炸 → 启用 gradient clipping (max_norm=1.0)
- 初始化错误 → 使用 scaled initialization
11.2 显存不足
解决方案(按效果排序):
- Gradient Checkpointing - 节省 50% 显存,速度降低 20%
- 降低批次大小 + 梯度累积
- FlashAttention - 节省 20× 注意力显存
- Mixed Precision - BF16/FP16
- ZeRO Stage 3 - 分布式显存优化
- 降低序列长度 - 分段处理
11.3 推理速度慢
诊断步骤:
- 检查是否启用 KV Cache
- 使用 FlashAttention 替代标准注意力
- 考虑量化(INT8/FP8)
- 批量推理(提高 GPU 利用率)
- 使用 vLLM/TensorRT-LLM 等推理引擎
12 扩展阅读
12.1 必读论文
基础:
- Attention is All You Need (Vaswani et al., 2017)
- BERT: Pre-training of Deep Bidirectional Transformers (Devlin et al., 2018)
- Language Models are Unsupervised Multitask Learners (GPT-2, 2019)
优化:
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- RoFormer: Enhanced Transformer with Rotary Position Embedding (Su et al., 2021)
- Train Short, Test Long: Attention with Linear Biases (ALiBi, 2021)
架构演进:
- LLaMA: Open and Efficient Foundation Language Models (Touvron et al., 2023)
- Mistral 7B (Jiang et al., 2023)
12.2 开源实现
训练框架:
- nanoGPT - 最小教学实现(Andrej Karpathy)
- Megatron-LM - NVIDIA 大规模训练
- DeepSpeed - 微软分布式训练
推理引擎:
- vLLM - PagedAttention + Continuous Batching
- TensorRT-LLM - NVIDIA 推理优化
- llama.cpp - CPU 高效推理
13 相关笔记链接
最后更新: 2025-01-07 维护者: sean2077