ops4_AdamW优化器的CUDA实现

ops(4):AdamW 优化器的 CUDA 实现

作者: 紫气东来
发布时间: 2024-05-02
原文链接: https://zhuanlan.zhihu.com/p/695611950


文章概述

AdamW 是深度学习领域最常用的优化器,它不仅决定了模型的训练效果,还占据了训练过程中约 75% 的显存,是关键的核心算子。本文深入讨论 AdamW 的工作原理,并展示其 CUDA 实现。


一、AdamW 算法原理

1.1 Momentum(动量)的意义

基本梯度下降的问题

最基本的梯度更新公式:

θ_t = θ_{t-1} - η * g_t

其中:

这种方法存在以下问题:

  1. 收敛速度慢:固定学习率导致开始阶段收敛缓慢
  2. 震荡问题:可能越过最小值点或在最小值点附近震荡
  3. 鞍点陷阱:容易陷入鞍点

Adam 算法公式

Adam 通过引入一阶动量和二阶动量解决上述问题:

m_t ← β₁ * m_{t-1} + (1 - β₁) * g_t          # 一阶动量更新
v_t ← β₂ * v_{t-1} + (1 - β₂) * g_t²         # 二阶动量更新
m̂_t ← m_t / (1 - β₁^t)                       # 一阶动量偏差修正
v̂_t ← v_t / (1 - β₂^t)                       # 二阶动量偏差修正
θ_t ← θ_{t-1} - η_t * (m̂_t / (√v̂_t + ε))   # 参数更新

关键组件解析

一阶动量 (m_t)

二阶动量 (v_t)

偏差修正 (Bias Correction)

1.2 Weight Decay 与 L2 正则化的区别

在 SGD 中的等价性

对于 SGD 优化器,weight decay 与 L2 regularization 完全等价:

方法 损失函数 梯度 参数更新
L2 正则化 L + λ‖θ‖² ∇L + 2λθ θ - α(∇L + 2λθ)
Weight Decay L ∇L θ - α∇L - αλ'θ

在 Adam 中的差异

当引入 Momentum 后,等价性被打破:

L2 正则化(红色路径)

损失函数: L + λ‖θ‖²
梯度: g_t = ∇L + 2λθ
通过一阶和二阶动量处理后,方向变得模糊

Weight Decay(绿色路径)

直接在参数更新时应用衰减:
θ_t ← θ_{t-1} - η * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})

关键差异

实验结果证明,AdamW(使用 weight decay)的效果显著优于 Adam + L2 正则化。


二、AdamW 的 CUDA 实现

2.1 优化器状态组成

AdamW 优化器需要维护以下状态:

  1. params: 模型参数
  2. grads: 梯度
  3. momentum (m): 一阶动量
  4. variance (v): 二阶动量

2.2 CPU 参考实现

// CPU code reference
void adamw_cpu(float* params_memory, const float* grads_memory, 
               float* m_memory, float* v_memory, int t, 
               long num_parameters, float learning_rate=1e-3, 
               float beta1=0.9, float beta2=0.999, 
               float eps=1e-8, float weight_decay=0.0) {
    
    for (int i = 0; i < num_parameters; i++) {
        float param = params_memory[i];
        float grad = grads_memory[i];

        // 更新一阶动量 (momentum)
        float m = beta1 * m_memory[i] + (1.0f - beta1) * grad;
        
        // 更新二阶动量 (RMSprop)
        float v = beta2 * v_memory[i] + (1.0f - beta2) * grad * grad;
        
        // 偏差修正
        float m_hat = m / (1.0f - powf(beta1, t));
        float v_hat = v / (1.0f - powf(beta2, t));

        // 更新参数(包含 weight decay)
        m_memory[i] = m;
        v_memory[i] = v;
        params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) 
                                             + weight_decay * param);
    }
}

实现要点

2.3 CUDA 基础实现

// naive fused kernel
__global__ void adamw_kernel1(float* params_memory, 
                              const float* grads_memory, 
                              float* m_memory, float* v_memory, 
                              long num_parameters,
                              float learning_rate, float beta1, float beta2, 
                              float beta1_correction, float beta2_correction, 
                              float eps, float weight_decay) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters) return;  // 边界检查
    
    // 更新一阶动量 (momentum)
    m_memory[i] = beta1 * m_memory[i] + (1.0f - beta1) * grads_memory[i];
    
    // 更新二阶动量 (RMSprop)
    v_memory[i] = beta2 * v_memory[i] + (1.0f - beta2) * grads_memory[i] * grads_memory[i];
    
    // 偏差修正
    float m_hat = m_memory[i] / beta1_correction;
    float v_hat = v_memory[i] / beta2_correction;
    
    // 参数更新
    params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) 
                                         + weight_decay * params_memory[i]);
}

CUDA 实现特点

  1. 并行化:每个线程处理一个参数
  2. 融合操作:将所有更新步骤融合在一个 kernel 中
  3. 预计算修正因子:beta1_correction = 1 - β₁^t,避免重复计算 pow()

2.4 性能对比

time gpu: 0.0409 ms
time cpu: 0.0612 ms

GPU 相比 CPU 加速约 1.5x,这只是基础实现的性能。


三、CUDA 优化技术分析

虽然原文未详细展开优化部分,但基于 CUDA 优化的通用原则,AdamW 可以从以下方面优化:

3.1 内存访问优化

问题:基础实现中存在多次全局内存访问

优化方向

  1. 合并访问:确保线程访问连续内存
  2. 向量化加载:使用 float4 一次加载 4 个元素
  3. 共享内存:对于需要多次访问的数据使用共享内存缓存

3.2 计算优化

避免昂贵操作

指令级优化

3.3 融合优化

Kernel 融合

多参数组融合

3.4 数值稳定性

防止数值溢出

// 使用 epsilon 防止除零
float denom = sqrtf(v_hat) + eps;

// 梯度裁剪
grad = fminf(fmaxf(grad, -grad_clip), grad_clip);

3.5 混合精度优化

FP16/BF16 计算


四、内存占用分析

4.1 显存占用构成

对于参数量为 N 的模型:

组件 数据类型 大小
参数 (params) FP32 4N bytes
梯度 (grads) FP32 4N bytes
一阶动量 (m) FP32 4N bytes
二阶动量 (v) FP32 4N bytes
总计 16N bytes

示例:GPT-2 (124M 参数)

优化器状态占总显存的 75%,这就是为什么 AdamW 是显存消耗的主要来源。

4.2 内存优化策略

状态分片

混合精度

CPU Offloading


五、实现要点总结

5.1 算法层面

  1. Weight Decay 的正确实现

    • 直接作用于参数更新,不经过动量
    • 与 L2 正则化在 Adam 中不等价
  2. 偏差修正的必要性

    • 训练初期(小 t 值)尤为重要
    • 可以预计算避免重复 pow() 调用
  3. 数值稳定性

    • epsilon 防止除零
    • 梯度裁剪防止爆炸

5.2 CUDA 实现层面

  1. Kernel 设计

    • 每线程一参数的简单映射
    • 融合所有更新步骤减少内存访问
  2. 内存访问模式

    • 确保合并访问
    • 考虑向量化加载
  3. 性能优化

    • 预计算常量
    • 使用快速数学函数
    • 考虑混合精度

5.3 工程实践

  1. 显存管理

    • 优化器状态是显存大头(75%)
    • 考虑状态分片或 CPU offloading
  2. 可扩展性

    • 支持参数分组(不同学习率)
    • 支持梯度累积
  3. 调试和验证

    • 与 PyTorch 实现对比
    • 检查数值精度
    • 性能 profiling

六、参考资料

  1. Karpathy's llm.c - AdamW CUDA Implementation
  2. AdamW优化器简单理解 - CSDN博客
  3. Understanding Weight Decay vs L2 Regularization
  4. Decoupled Weight Decay Regularization (AdamW Paper)

七、扩展阅读

相关优化器

进阶优化

分布式训练


燕子不归春事晚,一汀烟雨杏花寒。 —— 戴叔伦《苏溪亭》