ops4_AdamW优化器的CUDA实现
ops(4):AdamW 优化器的 CUDA 实现
作者: 紫气东来
发布时间: 2024-05-02
原文链接: https://zhuanlan.zhihu.com/p/695611950
文章概述
AdamW 是深度学习领域最常用的优化器,它不仅决定了模型的训练效果,还占据了训练过程中约 75% 的显存,是关键的核心算子。本文深入讨论 AdamW 的工作原理,并展示其 CUDA 实现。
一、AdamW 算法原理
1.1 Momentum(动量)的意义
基本梯度下降的问题
最基本的梯度更新公式:
θ_t = θ_{t-1} - η * g_t
其中:
- θ 是参数
- η 是学习率
- g_t 是梯度
这种方法存在以下问题:
- 收敛速度慢:固定学习率导致开始阶段收敛缓慢
- 震荡问题:可能越过最小值点或在最小值点附近震荡
- 鞍点陷阱:容易陷入鞍点
Adam 算法公式
Adam 通过引入一阶动量和二阶动量解决上述问题:
m_t ← β₁ * m_{t-1} + (1 - β₁) * g_t # 一阶动量更新
v_t ← β₂ * v_{t-1} + (1 - β₂) * g_t² # 二阶动量更新
m̂_t ← m_t / (1 - β₁^t) # 一阶动量偏差修正
v̂_t ← v_t / (1 - β₂^t) # 二阶动量偏差修正
θ_t ← θ_{t-1} - η_t * (m̂_t / (√v̂_t + ε)) # 参数更新
关键组件解析
一阶动量 (m_t)
- 代表惯性,考虑历史梯度的指数移动平均
- 帮助加速收敛,减少震荡
- 使优化过程更加平滑
二阶动量 (v_t)
- 用于自适应学习率控制
- 在分母位置,物理意义:
- 对于频繁更新的参数:降低学习率,避免被单个样本过度影响
- 对于稀疏更新的参数:提高学习率,从少量样本中多学习
偏差修正 (Bias Correction)
- 修正移动平均初始阶段的偏差
- 当 t 足够大时,m̂_t ≈ m_t
- 示例:当 t=1, β₁=0.9, m₀=0 时
- m₁ = 0.9 × 0 + 0.1 × g₁ = 0.1g₁
- 修正后:m̂₁ = 0.1g₁ / (1 - 0.9) = g₁
1.2 Weight Decay 与 L2 正则化的区别
在 SGD 中的等价性
对于 SGD 优化器,weight decay 与 L2 regularization 完全等价:
- L2 正则化项的梯度 = 权重衰减的参数更新
- 当 λ' = λ/α 时,两者参数更新完全相同
| 方法 | 损失函数 | 梯度 | 参数更新 |
|---|---|---|---|
| L2 正则化 | L + λ‖θ‖² | ∇L + 2λθ | θ - α(∇L + 2λθ) |
| Weight Decay | L | ∇L | θ - α∇L - αλ'θ |
在 Adam 中的差异
当引入 Momentum 后,等价性被打破:
L2 正则化(红色路径)
损失函数: L + λ‖θ‖²
梯度: g_t = ∇L + 2λθ
通过一阶和二阶动量处理后,方向变得模糊
Weight Decay(绿色路径)
直接在参数更新时应用衰减:
θ_t ← θ_{t-1} - η * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})
关键差异:
- L2 正则化:衰减项经过动量处理,效果被稀释
- Weight Decay:直接作用于参数,保持原始意图(让参数变小)
实验结果证明,AdamW(使用 weight decay)的效果显著优于 Adam + L2 正则化。
二、AdamW 的 CUDA 实现
2.1 优化器状态组成
AdamW 优化器需要维护以下状态:
- params: 模型参数
- grads: 梯度
- momentum (m): 一阶动量
- variance (v): 二阶动量
2.2 CPU 参考实现
// CPU code reference
void adamw_cpu(float* params_memory, const float* grads_memory,
float* m_memory, float* v_memory, int t,
long num_parameters, float learning_rate=1e-3,
float beta1=0.9, float beta2=0.999,
float eps=1e-8, float weight_decay=0.0) {
for (int i = 0; i < num_parameters; i++) {
float param = params_memory[i];
float grad = grads_memory[i];
// 更新一阶动量 (momentum)
float m = beta1 * m_memory[i] + (1.0f - beta1) * grad;
// 更新二阶动量 (RMSprop)
float v = beta2 * v_memory[i] + (1.0f - beta2) * grad * grad;
// 偏差修正
float m_hat = m / (1.0f - powf(beta1, t));
float v_hat = v / (1.0f - powf(beta2, t));
// 更新参数(包含 weight decay)
m_memory[i] = m;
v_memory[i] = v;
params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps)
+ weight_decay * param);
}
}
实现要点:
- 顺序遍历所有参数
- 每个参数独立更新动量和方差
- 应用偏差修正
- 最后更新参数值
2.3 CUDA 基础实现
// naive fused kernel
__global__ void adamw_kernel1(float* params_memory,
const float* grads_memory,
float* m_memory, float* v_memory,
long num_parameters,
float learning_rate, float beta1, float beta2,
float beta1_correction, float beta2_correction,
float eps, float weight_decay) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= num_parameters) return; // 边界检查
// 更新一阶动量 (momentum)
m_memory[i] = beta1 * m_memory[i] + (1.0f - beta1) * grads_memory[i];
// 更新二阶动量 (RMSprop)
v_memory[i] = beta2 * v_memory[i] + (1.0f - beta2) * grads_memory[i] * grads_memory[i];
// 偏差修正
float m_hat = m_memory[i] / beta1_correction;
float v_hat = v_memory[i] / beta2_correction;
// 参数更新
params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps)
+ weight_decay * params_memory[i]);
}
CUDA 实现特点:
- 并行化:每个线程处理一个参数
- 融合操作:将所有更新步骤融合在一个 kernel 中
- 预计算修正因子:beta1_correction = 1 - β₁^t,避免重复计算 pow()
2.4 性能对比
time gpu: 0.0409 ms
time cpu: 0.0612 ms
GPU 相比 CPU 加速约 1.5x,这只是基础实现的性能。
三、CUDA 优化技术分析
虽然原文未详细展开优化部分,但基于 CUDA 优化的通用原则,AdamW 可以从以下方面优化:
3.1 内存访问优化
问题:基础实现中存在多次全局内存访问
- 读取:params, grads, m, v
- 写入:params, m, v
优化方向:
- 合并访问:确保线程访问连续内存
- 向量化加载:使用 float4 一次加载 4 个元素
- 共享内存:对于需要多次访问的数据使用共享内存缓存
3.2 计算优化
避免昂贵操作:
- 预计算
1 - beta1,1 - beta2 - 预计算偏差修正因子
- 使用
rsqrtf()代替1.0f / sqrtf()
指令级优化:
- 利用 FMA (Fused Multiply-Add) 指令
- 减少寄存器使用,提高占用率
3.3 融合优化
Kernel 融合:
- 将梯度计算和优化器更新融合
- 减少中间结果的内存读写
多参数组融合:
- 对于小参数张量,批量处理多个参数组
3.4 数值稳定性
防止数值溢出:
// 使用 epsilon 防止除零
float denom = sqrtf(v_hat) + eps;
// 梯度裁剪
grad = fminf(fmaxf(grad, -grad_clip), grad_clip);
3.5 混合精度优化
FP16/BF16 计算:
- 动量和方差使用 FP32 保持精度
- 梯度和参数可使用 FP16 节省内存
- 关键计算使用 FP32 累加
四、内存占用分析
4.1 显存占用构成
对于参数量为 N 的模型:
| 组件 | 数据类型 | 大小 |
|---|---|---|
| 参数 (params) | FP32 | 4N bytes |
| 梯度 (grads) | FP32 | 4N bytes |
| 一阶动量 (m) | FP32 | 4N bytes |
| 二阶动量 (v) | FP32 | 4N bytes |
| 总计 | 16N bytes |
示例:GPT-2 (124M 参数)
- 参数:124M × 4 = 496 MB
- 优化器状态:124M × 12 = 1488 MB
- 总计:≈ 2 GB
优化器状态占总显存的 75%,这就是为什么 AdamW 是显存消耗的主要来源。
4.2 内存优化策略
状态分片:
- ZeRO 优化器:将优化器状态分片到多个 GPU
- 每个 GPU 只保存部分参数的完整状态
混合精度:
- 使用 FP16 存储梯度:节省 50% 梯度内存
- 保持 FP32 动量:确保数值稳定性
CPU Offloading:
- 将优化器状态卸载到 CPU 内存
- 更新时传输到 GPU,更新后传回
五、实现要点总结
5.1 算法层面
-
Weight Decay 的正确实现
- 直接作用于参数更新,不经过动量
- 与 L2 正则化在 Adam 中不等价
-
偏差修正的必要性
- 训练初期(小 t 值)尤为重要
- 可以预计算避免重复 pow() 调用
-
数值稳定性
- epsilon 防止除零
- 梯度裁剪防止爆炸
5.2 CUDA 实现层面
-
Kernel 设计
- 每线程一参数的简单映射
- 融合所有更新步骤减少内存访问
-
内存访问模式
- 确保合并访问
- 考虑向量化加载
-
性能优化
- 预计算常量
- 使用快速数学函数
- 考虑混合精度
5.3 工程实践
-
显存管理
- 优化器状态是显存大头(75%)
- 考虑状态分片或 CPU offloading
-
可扩展性
- 支持参数分组(不同学习率)
- 支持梯度累积
-
调试和验证
- 与 PyTorch 实现对比
- 检查数值精度
- 性能 profiling
六、参考资料
- Karpathy's llm.c - AdamW CUDA Implementation
- AdamW优化器简单理解 - CSDN博客
- Understanding Weight Decay vs L2 Regularization
- Decoupled Weight Decay Regularization (AdamW Paper)
七、扩展阅读
相关优化器
- Adam: AdamW 的前身,使用 L2 正则化
- AdaGrad: 自适应学习率的早期尝试
- RMSprop: 二阶动量的来源
- Lion: 最新的符号动量优化器
进阶优化
- 8-bit Adam: 量化优化器状态
- Adafactor: 分解二阶动量矩阵
- LAMB: 大批量训练的层级自适应优化器
分布式训练
- ZeRO: 优化器状态分片
- FSDP: 全分片数据并行
- Tensor Parallelism: 张量级并行
燕子不归春事晚,一汀烟雨杏花寒。 —— 戴叔伦《苏溪亭》