Transformer 架构详解

核心定位: 完全基于注意力机制的革命性神经网络架构,现代大语言模型的基础。 原始论文: Vaswani et al. (2017) “Attention is All You Need” [NeurIPS 2017]


1 架构概览

Transformer 是 2017 年提出的纯神经网络架构,完全抛弃了传统的 RNN/CNN,仅使用注意力机制处理序列数据。

1.1 核心创新

  • 并行计算: 所有位置同时处理,不像 RNN 串行递归
  • 全局依赖: 注意力机制直接连接任意距离的位置
  • 可扩展性: 从 100M 到 1000B+ 参数,架构基本不变
  • 迁移学习: 预训练模型可在多种下游任务微调

1.2 标准架构(Encoder-Decoder)

输入 → Embedding + Positional Encoding
     ↓
[Encoder × N 层]
- Multi-Head Self-Attention
- Add & Norm
- Feed-Forward Network
- Add & Norm
     ↓
[Decoder × N 层]
- Masked Multi-Head Self-Attention
- Add & Norm
- Cross-Attention (与 Encoder 输出)
- Add & Norm
- Feed-Forward Network
- Add & Norm
     ↓
Linear + Softmax → 输出概率

2 核心组件详解

2.1 多头自注意力(Multi-Head Self-Attention)

数学定义:

单头注意力:

多头注意力:

其中:

参数:

  • - Query 投影矩阵
  • - Key 投影矩阵
  • - Value 投影矩阵
  • - 输出投影矩阵

典型配置:

  • GPT-3: 头, ,
  • LLaMA 7B: 头, ,

计算复杂度:

  • 时间复杂度: 是序列长度)
  • 空间复杂度:(存储注意力矩阵)

优化技术:

  • FlashAttention: IO-aware 算法,显存节省 20×
  • 稀疏注意力: 仅计算部分位置(Sparse Transformer, Longformer)
  • 线性注意力: 近似计算,降低到 (Performer, Linear Transformer)

2.2 位置编码(Positional Encoding)

Transformer 本身无法区分位置顺序,需要显式注入位置信息。

2.2.1 正弦位置编码(原始论文)

优点: 无需训练,可外推到更长序列 缺点: 表达能力有限

2.2.2 可学习位置嵌入(BERT)

position_embeddings = nn.Embedding(max_position, d_model)

优点: 更强的表达能力 缺点: 不能外推到训练时未见过的长度

2.2.3 旋转位置编码 RoPE(LLaMA)

通过旋转矩阵注入相对位置信息:

优点:

  • 编码相对位置(更符合语言特性)
  • 更好的外推性能
  • 当前大模型主流选择(LLaMA, Mistral, Qwen)

2.2.4 ALiBi(Attention with Linear Biases)

在注意力分数上添加线性偏置:

优点: 极简设计,外推性能优秀(BLOOM, MPT)

2.3 前馈神经网络(FFN)

标准结构:

参数:

  • - 通常

变体:GLU 系列(性能更优)

SwiGLU(LLaMA):

其中 Swish

GeGLU(PaLM):

参数量对比:

  • 标准 FFN:
  • GLU 变体: (多 50%,但性能提升明显)

2.4 残差连接与层归一化

标准结构:

归一化位置对比:

类型结构代表模型优势
Post-LN(原始)BERT训练稳定
Pre-LNGPT-2, GPT-3更易训练深层
DeepNormLLaMA超深模型稳定

RMSNorm(简化版 LayerNorm):

  • 去掉均值中心化(只保留方差归一化)
  • 计算更快,效果相当(LLaMA, T5 采用)

3 三大变体架构

3.1 Encoder-only(理解任务)

代表模型: BERT, RoBERTa, DeBERTa

结构特点:

  • 双向注意力(可看到上下文)
  • 适合分类、实体识别、问答

典型应用:

  • 文本分类(情感分析)
  • 命名实体识别(NER)
  • 问答系统(SQuAD)
  • 语义相似度计算

3.2 Decoder-only(生成任务)

代表模型: GPT 系列, LLaMA, Mistral, Qwen

结构特点:

  • 单向注意力(因果掩码,只能看左边)
  • 自回归生成(逐 token 预测)
  • 当前大语言模型主流选择

核心机制:Causal Masking

mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# 上三角矩阵,防止看到未来信息

典型应用:

  • 文本生成
  • 对话系统
  • 代码生成
  • 指令遵循

3.3 Encoder-Decoder(序列到序列)

代表模型: T5, BART, mT5

结构特点:

  • Encoder 双向 + Decoder 单向
  • Cross-Attention 连接两者
  • 适合输入输出长度差异大的任务

典型应用:

  • 机器翻译
  • 文本摘要
  • 语法纠错
  • 问答生成

4 训练策略

4.1 预训练任务

任务类型目标代表模型公式
Masked Language Modeling预测被掩码词BERT
Causal Language Modeling预测下一个词GPT
Prefix Language Modeling预测后缀UniLM-
Span Corruption预测连续片段T5-
Denoising Autoencoding恢复原始文本BART-

4.2 优化技巧

学习率调度:

# Warmup + Cosine Decay(主流)
lr = peak_lr * min(step / warmup_steps, 
                   0.5 * (1 + cos(π * (step - warmup) / total_steps)))

梯度裁剪:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

混合精度训练:

  • BF16 优先(比 FP16 更稳定,避免溢出)
  • 关键参数(LayerNorm, Softmax)保持 FP32

显存优化:

  • Activation Checkpointing(重计算中间激活)
  • FlashAttention(IO-aware 注意力)
  • ZeRO(分布式显存优化)

5 推理优化

5.1 KV Cache

原理: 缓存已计算的 Key 和 Value,避免重复计算

节省计算:

  • 无 cache: 计算量
  • 有 cache: 计算量(生成阶段)

显存占用:

示例(LLaMA 7B):

  • 每个 token 的 KV:
  • 2048 上下文: (单个请求)

5.2 Paged KV Cache(vLLM)

核心思想: 将 KV cache 分页管理(类似操作系统虚拟内存)

优势:

  • 消除内存碎片
  • 支持共享 prefix(多请求共用系统提示)
  • 提高 GPU 利用率 ~3×

5.3 Continuous Batching

传统批处理问题:

  • 批次中最慢的请求决定整体延迟
  • GPU 利用率低

改进方案:

  • 动态插入新请求(一旦有空位)
  • 完成的请求立即移除
  • 吞吐提升 ~10×

6 长上下文扩展

6.1 问题根源

标准 Transformer 的 复杂度在长序列下不可承受:

  • 32K 上下文需要 1024× 于 1K 上下文的计算量
  • 显存占用呈平方增长

6.2 主流解决方案

方法原理代表模型上下文长度
Sparse Attention仅计算部分注意力Longformer4K-16K
Sliding Window局部注意力 + 全局 tokenMistral32K
ALiBi 外推线性偏置位置编码BLOOM, MPT可外推
RoPE 插值缩放旋转频率LLaMA 232K
YaRN改进 RoPE 插值Nous-Hermes64K-128K
FlashAttention-2IO 优化通用无上限
Ring Attention分布式长上下文研究中百万级

7 Transformer vs 其他架构

7.1 与 RNN 对比

维度RNN/LSTMTransformer
并行性❌ 串行递归✅ 完全并行
长距离依赖❌ 梯度消失✅ 直接全局连接
训练速度快(GPU 友好)
推理速度快(逐步计算)慢(需完整前向)
内存占用高(KV cache)
外推能力一般(需特殊位置编码)

7.2 与 CNN 对比

维度CNNTransformer
感受野局部(可堆叠)全局(单层)
参数共享✅ 卷积核共享❌ 每个位置独立
归纳偏置强(局部性)弱(数据驱动)
图像任务✅ 传统优势✅ ViT 追平/超越(大规模数据)
文本任务❌ 不适合✅ 标准选择

7.3 与 State Space Models(SSM/Mamba)

维度TransformerSSM/Mamba
时间复杂度
空间复杂度
并行训练✅ 完全并行✅ 完全并行
推理效率慢(需完整前向)快(递归更新)
长上下文需优化原生支持
生态成熟度✅ 极其成熟⚠️ 尚在发展

8 典型模型配置

8.1 GPT-3(175B)

架构: Decoder-only
层数: 96
隐藏维度: 12288
注意力头数: 96
FFN 维度: 49152 (4×)
词表大小: 50257
最大序列长度: 2048
位置编码: 可学习
激活函数: GELU

8.2 LLaMA 2(7B)

架构: Decoder-only
层数: 32
隐藏维度: 4096
注意力头数: 32
FFN 维度: 11008 (~2.7×)
GQA: 否(7B 使用 MHA)
词表大小: 32000
最大序列长度: 4096
位置编码: RoPE
激活函数: SiLU
归一化: RMSNorm (Pre-LN)

8.3 LLaMA 2(70B)

架构: Decoder-only
层数: 80
隐藏维度: 8192
注意力头数: 64
KV 头数: 8 (GQA)
FFN 维度: 28672
词表大小: 32000
最大序列长度: 4096
位置编码: RoPE
激活函数: SiLU
归一化: RMSNorm (Pre-LN)

GQA (Grouped Query Attention):

  • 减少 KV cache 大小
  • 保持性能的同时降低推理显存
  • 70B 使用 8 个 KV 头共享给 64 个 Q 头

9 实现代码示例

9.1 最小 Transformer Block(PyTorch)

import torch
import torch.nn as nn
 
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Multi-Head Attention
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        
        # Feed-Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        # Self-Attention + Residual + Norm
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + attn_out)
        
        # FFN + Residual + Norm
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        return x

9.2 RoPE 位置编码实现

def apply_rotary_emb(q, k, cos, sin):
    """应用旋转位置编码
    
    Args:
        q, k: [batch, seq_len, n_heads, head_dim]
        cos, sin: [seq_len, head_dim]
    """
    # 分离偶数和奇数维度
    q_rot = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).flatten(-2)
    k_rot = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).flatten(-2)
    
    # 应用旋转
    q_embed = q * cos + q_rot * sin
    k_embed = k * cos + k_rot * sin
    
    return q_embed, k_embed

9.3 Causal Mask 生成

def create_causal_mask(seq_len, device):
    """生成因果注意力掩码(Decoder-only)"""
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

10 性能优化 Checklist

10.1 训练阶段

  • 使用 FlashAttention-2(显存 ↓20×,速度 ↑2×)
  • 启用 BF16 混合精度(避免 FP16 溢出)
  • Gradient Checkpointing(显存 ↓50%,速度 ↓20%)
  • 梯度累积(模拟大批次)
  • ZeRO Stage 2/3(分布式显存优化)
  • 数据并行 + 张量并行(多 GPU)

10.2 推理阶段

  • KV Cache(必选,速度 ↑10×)
  • Paged Attention(vLLM,显存 ↑2-3×)
  • Continuous Batching(吞吐 ↑10×)
  • 量化 INT8/FP8(推理速度 ↑2×)
  • Speculative Decoding(草稿模型加速)
  • 批量推理(提高 GPU 利用率)

11 常见问题与调试

11.1 训练不稳定

症状: Loss 突然飙升为 NaN

原因与解决:

  1. FP16 溢出 → 改用 BF16
  2. 学习率过大 → 降低 peak_lr 或增加 warmup
  3. 梯度爆炸 → 启用 gradient clipping (max_norm=1.0)
  4. 初始化错误 → 使用 scaled initialization

11.2 显存不足

解决方案(按效果排序):

  1. Gradient Checkpointing - 节省 50% 显存,速度降低 20%
  2. 降低批次大小 + 梯度累积
  3. FlashAttention - 节省 20× 注意力显存
  4. Mixed Precision - BF16/FP16
  5. ZeRO Stage 3 - 分布式显存优化
  6. 降低序列长度 - 分段处理

11.3 推理速度慢

诊断步骤:

  1. 检查是否启用 KV Cache
  2. 使用 FlashAttention 替代标准注意力
  3. 考虑量化(INT8/FP8)
  4. 批量推理(提高 GPU 利用率)
  5. 使用 vLLM/TensorRT-LLM 等推理引擎

12 扩展阅读

12.1 必读论文

基础:

优化:

架构演进:

12.2 开源实现

训练框架:

推理引擎:


13 相关笔记链接


最后更新: 2025-01-07 维护者: sean2077