ops2_SoftMax算子的CUDA实现
ops(2):SoftMax算子的 CUDA 实现
原文链接: https://zhuanlan.zhihu.com/p/695307283
发布时间: 2024-12-19
文章ID: 695307283
文章概述
本文深入探讨了 SoftMax 算子的 CUDA 实现与优化,从基础的 Safe SoftMax 到高级的 Online SoftMax,涵盖了多种优化技术。文章通过 7 个版本的迭代优化,展示了如何将性能从 28.32 µs/token 提升到 1.37 µs/token,实现了约 20 倍的性能提升。
一、SoftMax 算法与数值稳定性
1.1 基本原理
SoftMax 是神经网络中的基本激活函数,用于将数值向量归一化为概率分布,广泛应用于多分类问题和 self-attention 机制。
标准 SoftMax 公式:
Softmax(x_i) = e^(x_i) / Σ_j e^(x_j)
1.2 数值稳定性问题
直接计算 SoftMax 存在数值溢出风险。当输入值较大时,e^(x_i) 可能超出浮点数表示范围。
Safe SoftMax 解决方案:
m = max(x)
Softmax(x_i) = e^(x_i - m) / Σ_j e^(x_j - m)
关键改进:
- 先减去最大值 m,确保指数运算的输入为非正数
- 避免了数值溢出,同时保持数学等价性
- 这是工业界标准做法
1.3 计算流程分析
Safe SoftMax 需要 3 个循环:
- 第一次循环:计算最大值
m = max(x) - 第二次循环:计算指数和
sum = Σ e^(x_i - m) - 第三次循环:归一化
a_i = e^(x_i - m) / sum
性能瓶颈:
- 需要全局信息(最大值和总和)
- 多次遍历数据,增加内存访问开销
- 在大规模数据处理中效率较低
二、Online SoftMax 算法
2.1 算法原理
Online SoftMax 是 FlashAttention 的核心技术之一,通过迭代方式减少循环次数。
关键创新:将 3 个循环优化为 2 个循环
迭代公式:
for i ← 1, N do
m_i ← max(m_{i-1}, x_i)
sum'_i ← sum'_{i-1} * e^(m_{i-1} - m_i) + e^(x_i - m_i)
end
for i ← 1, N do
a_i ← e^(x_i - m_N) / sum'_N
end
2.2 数学推导
核心是 sum'_i 的增量更新:
sum'_i = Σ_{j=1}^i e^(x_j - m_i)
= (Σ_{j=1}^{i-1} e^(x_j - m_i)) + e^(x_i - m_i)
= (Σ_{j=1}^{i-1} e^(x_j - m_{i-1})) * e^(m_{i-1} - m_i) + e^(x_i - m_i)
= sum'_{i-1} * e^(m_{i-1} - m_i) + e^(x_i - m_i)
优势:
- 无需全局信息即可迭代计算
- 减少一次数据遍历
- 支持流式处理和分块计算
- 为 FlashAttention 的 IO 优化奠定基础
三、CUDA 实现与优化历程
版本对比总结
| 版本 | 方法 | Block Size | 性能 (µs/token) | 加速比 |
|---|---|---|---|---|
| V1 | Safe SoftMax + Global Memory | 32 | 7.53 | 1.0x |
| V1 | Safe SoftMax + Global Memory | 1024 | 28.32 | 0.27x |
| V2 | Safe SoftMax + Shared Memory | 1024 | 1.83 | 4.1x |
| V3 | Safe SoftMax + Warp Reduce | 32 | 2.95 | 2.6x |
| V4 | Safe SoftMax + Warp + Shared | 1024 | 1.97 | 3.8x |
| V5 | Online SoftMax + Global Memory | 32 | 5.80 | 1.3x |
| V6 | Online SoftMax + 协作组 | 32 | 1.37 | 5.5x |
| V7 | Online SoftMax + Unroll | 1024 | 1.47 | 5.1x |
3.1 V1:基础实现(Global Memory)
实现特点:
- 在 N 维度上并行,每个线程处理一行
- 完全使用 global memory
- 直接翻译 CPU 代码
性能问题:
- Global memory 访问延迟高(400-800 cycles)
- 无缓存优化
- Block size 增大时性能急剧下降
性能数据:
- Block size 32: 7.53 µs/token
- Block size 1024: 28.32 µs/token(性能下降 3.8 倍)
3.2 V2:共享内存优化
核心技术:
- 每个 block 处理一行数据(长度 C)
- 使用动态共享内存缓存中间结果
- 实现 reduce 操作求最大值和总和
实现细节:
extern __shared__ float shared[];
// Thread coarsening: 每个线程处理 C/block_size 个元素
for (int i = tid; i < C; i += block_size) {
maxval = fmaxf(maxval, x[i]);
}
shared[tid] = maxval;
// Reduction: 树形归约
for (int stride = block_size / 2; stride >= 1; stride /= 2) {
if (tid < stride) {
shared[tid] = fmaxf(shared[tid], shared[tid + stride]);
}
}
性能提升:
- Block size 1024: 从 28.32 µs 降至 1.83 µs
- 提升 15.5 倍
- Shared memory 延迟仅 20-40 cycles
3.3 V3:Warp 级归约优化
Warp 基础知识:
- Warp 是 GPU 调度的基本单位,包含 32 个线程
- Warp 内线程执行 SIMT(单指令多线程)
- 支持高效的线程间通信
核心函数:__shfl_down_sync
Warp Reduce 实现:
__device__ float warpReduceMax(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
}
return val;
}
__device__ float warpReduceSum(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xFFFFFFFF, val, offset);
}
return val;
}
工作原理:
__shfl_down_sync在 warp 内直接交换寄存器数据- 无需共享内存,延迟更低
- 5 次迭代完成 32 个线程的归约(log₂(32) = 5)
限制:
- 必须使用 block_size = 32
- 性能 2.95 µs/token,未达到 V2 水平
3.4 V4:Warp + Shared Memory 混合优化
设计思路:
- Warp 内使用
__shfl_down_sync归约 - Warp 间使用 shared memory 归约
- 支持任意 block size(32 的倍数)
两级归约架构:
int warpId = threadIdx.x / 32;
int laneId = threadIdx.x % 32;
int warpsPerBlock = blockDim.x / 32;
// 第一级:Warp 内归约
maxval = warpReduceMax(maxval);
if (laneId == 0) maxvals[warpId] = maxval; // 每个 warp 的结果写入 shared memory
// 第二级:Warp 间归约
if (tid == 0) {
float val = maxvals[0];
for (int i = 1; i < warpsPerBlock; i++) {
val = fmaxf(val, maxvals[i]);
}
maxvals[0] = val;
}
性能数据:
- Block size 1024: 1.97 µs/token
- 结合了 warp 和 shared memory 的优势
- 灵活支持不同 block size
3.5 V5:Online SoftMax 基础实现
算法切换:从 Safe SoftMax 改为 Online SoftMax
实现代码:
float maxval = -INFINITY;
double sum = 0.0;
for (int j = 0; j < C; j++) {
float maxval_prev = maxval;
if (inp_row[j] > maxval) {
maxval = inp_row[j];
sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
} else {
sum += expf(inp_row[j] - maxval);
}
}
性能问题:
- 5.80 µs/token,反而比 V1 慢
- 原因:增加了分支判断和额外的 exp 计算
- 未充分利用 GPU 并行性
3.6 V6:协作组 + 结构体(最优方案)
核心创新:
- 结构体封装:
struct __align__(8) SumMax {
float maxval;
float sum;
};
- 统一的 reduce 操作:
__device__ __forceinline__ SumMax reduce_sum_max_op(SumMax a, SumMax b) {
bool a_bigger = (a.maxval > b.maxval);
SumMax bigger_m = a_bigger ? a : b;
SumMax smaller_m = a_bigger ? b : a;
SumMax res;
res.maxval = bigger_m.maxval;
res.sum = bigger_m.sum + smaller_m.sum * expf(smaller_m.maxval - bigger_m.maxval);
return res;
}
- 协作组 API:
namespace cg = cooperative_groups;
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// 一次 reduce 同时完成 max 和 sum
SumMax sm_total = cg::reduce(warp, sm_partial, reduce_sum_max_op);
关键优势:
- 将 max 和 sum 两个 reduce 合并为一个
- 使用
__stcs指令绕过缓存直接写入 - 协作组提供更高层次的抽象
性能突破:
- 1.37 µs/token
- 相比 V1 提升 5.5 倍
- 相比 V4 提升 1.44 倍
- 所有 block size 性能稳定
3.7 V7:向量化访问优化
目标场景:C 值非常大的情况
核心技术:
- 循环展开:
#pragma unroll - 寄存器数组:减少内存访问
- 向量化读取:一次处理 8 个元素
实现细节:
const int UNROLL_FACTOR = 8;
// 向量化读取
for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
float reg_array[UNROLL_FACTOR];
#pragma unroll
for (int u = 0; u < UNROLL_FACTOR; u++) {
reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);
}
// 处理 reg_array
}
特殊指令:
__ldcs: 绕过 L1 缓存的流式加载- 适合顺序访问大数据
性能数据:
- Block size 1024: 1.47 µs/token
- 在大 C 值时优于 V6
- 小 C 值时略逊于 V6
四、性能优化技术总结
4.1 内存层次优化
内存访问延迟对比:
- 寄存器:1 cycle
- Shared Memory:20-40 cycles
- L1 Cache:~80 cycles
- Global Memory:400-800 cycles
优化策略:
- 最大化寄存器使用:循环展开、寄存器数组
- 利用 Shared Memory:缓存频繁访问的数据
- 减少 Global Memory 访问:Thread coarsening
- 绕过缓存:使用
__ldcs、__stcs指令
4.2 并行归约技术
三种归约方式:
-
Shared Memory 归约(V2):
- 树形归约,log(N) 步
- 需要
__syncthreads()同步 - 适合 block 级归约
-
Warp 归约(V3):
- 使用
__shfl_down_sync - 无需同步,延迟最低
- 仅限 warp 内(32 线程)
- 使用
-
混合归约(V4):
- Warp 内 + Warp 间
- 灵活支持大 block size
- 性能与灵活性的平衡
4.3 向量化访问
技术要点:
- 循环展开:
#pragma unroll减少循环开销 - 寄存器数组:一次加载多个元素
- 内存合并访问:提高带宽利用率
- 流式加载:
__ldcs适合顺序访问
适用场景:
- C 值较大(> 1024)
- 顺序访问模式
- 内存带宽受限
4.4 指令级优化
特殊指令:
__shfl_down_sync:Warp 内数据交换__ldcs:绕过 L1 的流式加载__stcs:绕过缓存的流式存储__forceinline__:强制内联减少函数调用开销
4.5 数值稳定性保证
Safe SoftMax 技术:
- 减去最大值避免溢出
- 使用 double 精度累加 sum
- 保持数学等价性
Online SoftMax 稳定性:
- 动态更新最大值
- 指数差值始终非正
- 增量更新保证精度
五、性能测试与分析
5.1 测试环境
- 数据规模:N × C 矩阵
- 测试指标:每个 token 的处理时间(µs/token)
- Block size 范围:32, 64, 128, 256, 512, 1024
5.2 关键性能数据
最佳性能对比:
- V1(基础):7.53 µs/token(block_size=32)
- V2(Shared Memory):1.83 µs/token(block_size=1024)
- V4(Warp+Shared):1.97 µs/token(block_size=1024)
- V6(协作组):1.37 µs/token(block_size=32)✓ 最优
- V7(向量化):1.47 µs/token(block_size=1024)
性能提升:
- V2 相比 V1:4.1 倍
- V6 相比 V1:5.5 倍
- V6 相比 V4:1.44 倍
5.3 Block Size 影响分析
V1(Global Memory):
- Block size 增大,性能急剧下降
- 1024 时性能下降到 28.32 µs/token
- 原因:内存访问冲突增加
V2-V4(优化版本):
- Block size 增大,性能提升
- 更好的并行度和资源利用
- 1024 时达到最佳性能
V6(协作组):
- 所有 block size 性能稳定
- 1.37-1.43 µs/token 范围
- 架构设计更优
六、实践建议
6.1 版本选择指南
推荐使用 V6(协作组版本):
- 性能最优(1.37 µs/token)
- 代码简洁,易于维护
- 对 block size 不敏感
- 适合生产环境
特殊场景:
- C 值极大(> 2048):考虑 V7(向量化)
- 资源受限:使用 V4(Warp+Shared)
- 学习目的:从 V1 到 V6 逐步理解
6.2 优化原则
- 先保证正确性:数值稳定性优先
- 内存优化优先:减少 global memory 访问
- 充分利用硬件:Warp、Shared Memory、特殊指令
- 测试驱动优化:每次优化都要测量性能
- 权衡复杂度:性能提升要与代码复杂度平衡
6.3 扩展应用
FlashAttention 启示:
- Online SoftMax 支持分块计算
- 减少 HBM 访问次数
- 在 SRAM 中完成更多计算
- 适合长序列处理
其他算子优化:
- LayerNorm:类似的归约操作
- Attention:结合 Online SoftMax
- Reduction 算子:通用优化模式
七、代码资源
源代码仓库:
- 作者实现:https://github.com/ifromeast/cuda_learning/blob/main/04_transformer/ops/softmax_forward.cu
- 原始代码:https://github.com/karpathy/llm.c/blob/master/dev/cuda/softmax_forward.cu
参考资料:
- NVIDIA 官方文档:使用 CUDA 扭曲级别基本体
- OneFlow 技术博客:如何实现一个高效的Softmax CUDA kernel
八、核心要点总结
数值稳定性
- Safe SoftMax 通过减去最大值避免溢出
- Online SoftMax 通过增量更新保持稳定性
- 使用 double 精度累加提高精度
算法优化
- Online SoftMax 将 3 个循环减少到 2 个
- 支持流式处理和分块计算
- 为 FlashAttention 等高级优化奠定基础
内存优化
- Shared Memory 相比 Global Memory 提升 15 倍
- Warp 级操作延迟最低
- 向量化访问提高带宽利用率
并行优化
- Thread coarsening 提高并行度
- 两级归约(Warp 内 + Warp 间)
- 协作组提供更高层次抽象
性能提升
- 从 28.32 µs/token 优化到 1.37 µs/token
- 总体提升约 20 倍
- V6 协作组版本为最优方案
结语:本文展示了从基础实现到高度优化的完整过程,体现了 CUDA 编程中内存层次、并行模式、硬件特性的深度结合。这些优化技术不仅适用于 SoftMax,也为其他算子优化提供了宝贵的参考。