ops2_SoftMax算子的CUDA实现

ops(2):SoftMax算子的 CUDA 实现

原文链接: https://zhuanlan.zhihu.com/p/695307283
发布时间: 2024-12-19
文章ID: 695307283


文章概述

本文深入探讨了 SoftMax 算子的 CUDA 实现与优化,从基础的 Safe SoftMax 到高级的 Online SoftMax,涵盖了多种优化技术。文章通过 7 个版本的迭代优化,展示了如何将性能从 28.32 µs/token 提升到 1.37 µs/token,实现了约 20 倍的性能提升。


一、SoftMax 算法与数值稳定性

1.1 基本原理

SoftMax 是神经网络中的基本激活函数,用于将数值向量归一化为概率分布,广泛应用于多分类问题和 self-attention 机制。

标准 SoftMax 公式

Softmax(x_i) = e^(x_i) / Σ_j e^(x_j)

1.2 数值稳定性问题

直接计算 SoftMax 存在数值溢出风险。当输入值较大时,e^(x_i) 可能超出浮点数表示范围。

Safe SoftMax 解决方案

m = max(x)
Softmax(x_i) = e^(x_i - m) / Σ_j e^(x_j - m)

关键改进

1.3 计算流程分析

Safe SoftMax 需要 3 个循环

  1. 第一次循环:计算最大值 m = max(x)
  2. 第二次循环:计算指数和 sum = Σ e^(x_i - m)
  3. 第三次循环:归一化 a_i = e^(x_i - m) / sum

性能瓶颈


二、Online SoftMax 算法

2.1 算法原理

Online SoftMax 是 FlashAttention 的核心技术之一,通过迭代方式减少循环次数。

关键创新:将 3 个循环优化为 2 个循环

迭代公式

for i ← 1, N do
    m_i ← max(m_{i-1}, x_i)
    sum'_i ← sum'_{i-1} * e^(m_{i-1} - m_i) + e^(x_i - m_i)
end

for i ← 1, N do
    a_i ← e^(x_i - m_N) / sum'_N
end

2.2 数学推导

核心是 sum'_i 的增量更新:

sum'_i = Σ_{j=1}^i e^(x_j - m_i)
       = (Σ_{j=1}^{i-1} e^(x_j - m_i)) + e^(x_i - m_i)
       = (Σ_{j=1}^{i-1} e^(x_j - m_{i-1})) * e^(m_{i-1} - m_i) + e^(x_i - m_i)
       = sum'_{i-1} * e^(m_{i-1} - m_i) + e^(x_i - m_i)

优势


三、CUDA 实现与优化历程

版本对比总结

版本 方法 Block Size 性能 (µs/token) 加速比
V1 Safe SoftMax + Global Memory 32 7.53 1.0x
V1 Safe SoftMax + Global Memory 1024 28.32 0.27x
V2 Safe SoftMax + Shared Memory 1024 1.83 4.1x
V3 Safe SoftMax + Warp Reduce 32 2.95 2.6x
V4 Safe SoftMax + Warp + Shared 1024 1.97 3.8x
V5 Online SoftMax + Global Memory 32 5.80 1.3x
V6 Online SoftMax + 协作组 32 1.37 5.5x
V7 Online SoftMax + Unroll 1024 1.47 5.1x

3.1 V1:基础实现(Global Memory)

实现特点

性能问题

性能数据

3.2 V2:共享内存优化

核心技术

实现细节

extern __shared__ float shared[];
// Thread coarsening: 每个线程处理 C/block_size 个元素
for (int i = tid; i < C; i += block_size) {
    maxval = fmaxf(maxval, x[i]);
}
shared[tid] = maxval;

// Reduction: 树形归约
for (int stride = block_size / 2; stride >= 1; stride /= 2) {
    if (tid < stride) {
        shared[tid] = fmaxf(shared[tid], shared[tid + stride]);
    }
}

性能提升

3.3 V3:Warp 级归约优化

Warp 基础知识

核心函数__shfl_down_sync

Warp Reduce 实现

__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

工作原理

限制

3.4 V4:Warp + Shared Memory 混合优化

设计思路

两级归约架构

int warpId = threadIdx.x / 32;
int laneId = threadIdx.x % 32;
int warpsPerBlock = blockDim.x / 32;

// 第一级:Warp 内归约
maxval = warpReduceMax(maxval);
if (laneId == 0) maxvals[warpId] = maxval;  // 每个 warp 的结果写入 shared memory

// 第二级:Warp 间归约
if (tid == 0) {
    float val = maxvals[0];
    for (int i = 1; i < warpsPerBlock; i++) {
        val = fmaxf(val, maxvals[i]);
    }
    maxvals[0] = val;
}

性能数据

3.5 V5:Online SoftMax 基础实现

算法切换:从 Safe SoftMax 改为 Online SoftMax

实现代码

float maxval = -INFINITY;
double sum = 0.0;
for (int j = 0; j < C; j++) {
    float maxval_prev = maxval;
    if (inp_row[j] > maxval) {
        maxval = inp_row[j];
        sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
    } else {
        sum += expf(inp_row[j] - maxval);
    }
}

性能问题

3.6 V6:协作组 + 结构体(最优方案)

核心创新

  1. 结构体封装
struct __align__(8) SumMax {
    float maxval;
    float sum;
};
  1. 统一的 reduce 操作
__device__ __forceinline__ SumMax reduce_sum_max_op(SumMax a, SumMax b) {
    bool a_bigger = (a.maxval > b.maxval);
    SumMax bigger_m = a_bigger ? a : b;
    SumMax smaller_m = a_bigger ? b : a;
    SumMax res;
    res.maxval = bigger_m.maxval;
    res.sum = bigger_m.sum + smaller_m.sum * expf(smaller_m.maxval - bigger_m.maxval);
    return res;
}
  1. 协作组 API
namespace cg = cooperative_groups;
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

// 一次 reduce 同时完成 max 和 sum
SumMax sm_total = cg::reduce(warp, sm_partial, reduce_sum_max_op);

关键优势

性能突破

3.7 V7:向量化访问优化

目标场景:C 值非常大的情况

核心技术

  1. 循环展开#pragma unroll
  2. 寄存器数组:减少内存访问
  3. 向量化读取:一次处理 8 个元素

实现细节

const int UNROLL_FACTOR = 8;

// 向量化读取
for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
    float reg_array[UNROLL_FACTOR];
    #pragma unroll
    for (int u = 0; u < UNROLL_FACTOR; u++) {
        reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);
    }
    // 处理 reg_array
}

特殊指令

性能数据


四、性能优化技术总结

4.1 内存层次优化

内存访问延迟对比

优化策略

  1. 最大化寄存器使用:循环展开、寄存器数组
  2. 利用 Shared Memory:缓存频繁访问的数据
  3. 减少 Global Memory 访问:Thread coarsening
  4. 绕过缓存:使用 __ldcs__stcs 指令

4.2 并行归约技术

三种归约方式

  1. Shared Memory 归约(V2):

    • 树形归约,log(N) 步
    • 需要 __syncthreads() 同步
    • 适合 block 级归约
  2. Warp 归约(V3):

    • 使用 __shfl_down_sync
    • 无需同步,延迟最低
    • 仅限 warp 内(32 线程)
  3. 混合归约(V4):

    • Warp 内 + Warp 间
    • 灵活支持大 block size
    • 性能与灵活性的平衡

4.3 向量化访问

技术要点

适用场景

4.4 指令级优化

特殊指令

4.5 数值稳定性保证

Safe SoftMax 技术

Online SoftMax 稳定性


五、性能测试与分析

5.1 测试环境

5.2 关键性能数据

最佳性能对比

性能提升

5.3 Block Size 影响分析

V1(Global Memory)

V2-V4(优化版本)

V6(协作组)


六、实践建议

6.1 版本选择指南

推荐使用 V6(协作组版本)

特殊场景

6.2 优化原则

  1. 先保证正确性:数值稳定性优先
  2. 内存优化优先:减少 global memory 访问
  3. 充分利用硬件:Warp、Shared Memory、特殊指令
  4. 测试驱动优化:每次优化都要测量性能
  5. 权衡复杂度:性能提升要与代码复杂度平衡

6.3 扩展应用

FlashAttention 启示

其他算子优化


七、代码资源

源代码仓库

参考资料

  1. NVIDIA 官方文档:使用 CUDA 扭曲级别基本体
  2. OneFlow 技术博客:如何实现一个高效的Softmax CUDA kernel

八、核心要点总结

数值稳定性

算法优化

内存优化

并行优化

性能提升


结语:本文展示了从基础实现到高度优化的完整过程,体现了 CUDA 编程中内存层次、并行模式、硬件特性的深度结合。这些优化技术不仅适用于 SoftMax,也为其他算子优化提供了宝贵的参考。