ops3_Cross_Entropy的CUDA实现
ops(3):Cross Entropy 的 CUDA 实现 - 深度分析
原文链接: https://zhuanlan.zhihu.com/p/695594396
发布时间: 2024-05-01
分析日期: 2026-05-05
一、Cross Entropy 损失函数原理
1.1 基本概念
Cross Entropy(交叉熵)是深度学习中最常用的损失函数,特别在多分类任务中与 Softmax 配合使用。其核心作用是衡量预测概率分布与真实标签分布之间的差异。
1.2 数学公式
完整的交叉熵公式为:
cross_entropy = -∑(k=1 to N) [p_k * log(q_k)]
其中:
p表示真实标签值(one-hot 编码形式)q表示预测值(经过 Softmax 后的概率分布,范围在 [0,1])N表示类别总数
1.3 公式简化
在训练过程中,由于真实标签采用 one-hot 编码,只有正确类别位置 p_m^label = 1,其余位置为 0。因此公式可以简化为:
cross_entropy = -log(q_m)
这意味着只需要计算正确类别对应的预测概率的负对数,大大简化了计算复杂度。
二、数值稳定性问题分析
2.1 核心问题
Cross Entropy 计算中涉及对数运算 log(q_m),当预测概率 q_m 接近 0 时会出现数值不稳定:
- 下溢问题:当
q_m → 0时,log(q_m) → -∞,可能导致数值溢出 - 梯度消失:极小的概率值会导致梯度接近 0,影响训练效果
- 精度损失:浮点数表示范围有限,极小值可能被截断为 0
2.2 解决方案
文章采用的策略是将 Cross Entropy 与 Softmax 融合计算,这样可以:
- 避免中间结果:不需要显式存储 Softmax 输出的概率值
- 数值稳定技巧:在 Softmax 计算时使用 log-sum-exp 技巧
- 减少内存访问:融合算子减少了数据在 GPU 内存中的读写次数
2.3 实际影响
从代码实现可以看出,输入 probs 已经是经过 Softmax 处理的概率分布,这意味着:
- 概率值已经归一化到 [0,1] 区间
- 但仍需要注意
log(0)的情况,实际实现中通常会添加一个极小值 epsilon(如 1e-7)
三、CUDA 实现策略
3.1 前向传播实现
CPU 参考实现
void crossentropy_forward_cpu(float* losses,
const float* probs, const int* targets,
int B, int T, int V) {
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
const float* probs_bt = probs + b * T * V + t * V;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
}
}
关键点:
- 输入维度:
probs为(B, T, V),其中 B=batch_size, T=sequence_length, V=vocab_size - 输出维度:
losses为(B, T),token-level 的损失值 - 计算逻辑:对每个 token 位置,取出目标类别的概率并计算负对数
CUDA 并行化策略
__global__ void crossentropy_forward_kernel1(float* losses,
const float* probs, const int* targets,
int B, int T, int V) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < B * T) {
int b = i / T;
int t = i % T;
const float* probs_bt = probs + b * T * V + t * V;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
}
并行化维度:
- 在
(B, T)维度上并行,每个线程处理一个 token - 总线程数:
B * T - 每个线程只需要一次内存读取和一次对数计算
性能数据分析:
block_size 32 | time 0.0032 ms | per token 0.39 ns
block_size 64 | time 0.0031 ms | per token 0.38 ns
block_size 128 | time 0.0031 ms | per token 0.38 ns
block_size 256 | time 0.0031 ms | per token 0.38 ns
block_size 512 | time 0.0032 ms | per token 0.39 ns
block_size 1024 | time 0.0037 ms | per token 0.45 ns
性能观察:
- 最优 block_size 为 64-256,达到 0.38 ns/token
- block_size 过小(32)或过大(1024)都会导致性能下降
- 计算极其简单,主要瓶颈在内存访问而非计算
3.2 反向传播实现
数学推导
设 z_i 表示 logits,S(z_i) 表示经过 Softmax 后的概率,损失函数为:
Loss = -ln(S(z_i))
对所有 z_j 求导,需要分情况讨论:
情况1:当 i = j(正确类别)
∂Loss/∂z_j = -1/S(z_i) * ∂S(z_i)/∂z_j
= -1/S(z_i) * S(z_i)(1 - S(z_i))
= S(z_i) - 1
= S(z_j) - 1
情况2:当 i ≠ j(错误类别)
∂Loss/∂z_j = -1/S(z_i) * ∂S(z_i)/∂z_j
= -1/S(z_i) * (-S(z_i) * S(z_j))
= S(z_j)
= S(z_j) - 0
统一表达式:
∂Loss/∂z_j = S(z_j) - indicator(j == i)
其中 indicator 是指示函数,正确类别为 1,其他为 0。
CPU 参考实现
void crossentropy_softmax_backward_cpu(float* dlogits,
const float* dlosses, const float* probs, const int* targets,
int B, int T, int V) {
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dlogits_bt = dlogits + b * T * V + t * V;
const float* probs_bt = probs + b * T * V + t * V;
float dloss = dlosses[b * T + t];
int ix = targets[b * T + t];
for (int i = 0; i < V; i++) {
float p = probs_bt[i];
float indicator = i == ix ? 1.0f : 0.0f;
dlogits_bt[i] += (p - indicator) * dloss;
}
}
}
}
CUDA 并行化策略
__global__ void crossentropy_softmax_backward_kernel1(float* dlogits,
const float* dlosses, const float* probs, const int* targets,
int B, int T, int V) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < B * T * V) {
int b = i / (T * V);
int t = (i / V) % T;
int v = i % V;
float* dlogits_bt = dlogits + b * T * V + t * V;
const float* probs_bt = probs + b * T * V + t * V;
float dloss = dlosses[b * T + t];
int ix = targets[b * T + t];
float p = probs_bt[v];
float indicator = v == ix ? 1.0f : 0.0f;
dlogits_bt[v] += (p - indicator) * dloss;
}
}
并行化维度:
- 在
(B, T, V)三个维度上全并行 - 总线程数:
B * T * V - 每个线程处理一个 logit 的梯度
性能数据分析:
block_size 32 | time 20.2376 ms | per token 2.47 µs
block_size 64 | time 10.0498 ms | per token 1.23 µs
block_size 128 | time 6.2755 ms | per token 0.77 µs
block_size 256 | time 6.2235 ms | per token 0.76 µs
block_size 512 | time 6.2832 ms | per token 0.77 µs
block_size 1024 | time 6.5979 ms | per token 0.81 µs
性能观察:
- 最优 block_size 为 256,达到 0.76 µs/token
- 相比前向传播慢约 2000 倍(0.38 ns vs 0.76 µs)
- 原因:反向传播需要处理完整的 vocab 维度,计算量为
O(B*T*V)
四、与 SoftMax 的融合优化
4.1 为什么要融合
文章明确指出:"由于计算过于简单,通常与 softmax 一并操作"。融合的优势包括:
-
减少内存访问:
- 分离实现:Softmax 输出 → 写入内存 → Cross Entropy 读取
- 融合实现:Softmax 计算后直接用于 Cross Entropy,无需中间存储
-
提高数值稳定性:
- 可以使用 log-softmax 技巧:
log(softmax(x)) = x - log(sum(exp(x))) - 避免先计算
exp再取log的精度损失
- 可以使用 log-softmax 技巧:
-
简化反向传播:
- 融合后的梯度公式更简洁:
dL/dz = softmax(z) - y - 减少中间梯度的存储和传递
- 融合后的梯度公式更简洁:
4.2 融合实现的数学原理
标准流程:
logits → Softmax → probs → Cross Entropy → loss
融合流程:
logits → Fused Softmax-CrossEntropy → loss
数值稳定的 log-softmax:
log_softmax(x_i) = x_i - log(∑_j exp(x_j))
= x_i - max(x) - log(∑_j exp(x_j - max(x)))
融合的 Cross Entropy:
loss = -log_softmax(x_target)
= -(x_target - max(x) - log(∑_j exp(x_j - max(x))))
4.3 反向传播的融合优势
从推导可以看出,Softmax + Cross Entropy 的联合梯度非常简洁:
∂Loss/∂logits = probs - targets
这个结果比分别计算 Cross Entropy 和 Softmax 的梯度再链式求导要简单得多,这也是为什么实际框架中都采用融合实现的原因。
融合优化的关键技巧:
-
Log-Sum-Exp 技巧:避免数值上溢/下溢
log(∑ exp(x_i)) = max(x) + log(∑ exp(x_i - max(x))) -
一次遍历计算:在同一个 kernel 中完成 softmax 和 cross entropy
- 减少全局内存访问
- 提高缓存命中率
- 降低延迟
-
梯度简化:直接计算
probs - targets,无需链式法则
五、性能分析与优化建议
5.1 前向传播性能分析
当前性能:0.38 ns/token(最优配置)
瓶颈分析:
- 计算强度低:只有一次
log运算 - 内存访问模式:随机访问(通过
targets[i]索引) - 优化空间有限:计算已经足够简单
可能的优化方向:
- 向量化加载:使用
float4等向量类型(但由于是随机访问,效果有限) - 与 Softmax 融合:减少内存往返
- 使用 Tensor Core:对于大 batch 可以考虑矩阵化操作
5.2 反向传播性能分析
当前性能:0.76 µs/token(最优配置)
瓶颈分析:
- 计算量大:需要处理完整的
V维度(通常 V=50000+) - 内存带宽受限:需要读取完整的
probs数组 - 写入冲突:多个线程可能写入同一个
dlogits位置(需要原子操作)
优化建议:
-
内存访问优化:
- 使用共享内存缓存
probs_bt和dloss - 合并内存访问,提高带宽利用率
- 使用共享内存缓存
-
计算优化:
- 向量化操作:使用
float4一次处理 4 个元素 - 循环展开:减少分支预测开销
- 向量化操作:使用
-
并行策略优化:
- 考虑在
V维度上使用 warp-level 并行 - 使用 block-level reduction 减少原子操作
- 考虑在
-
融合优化:
- 与 Softmax 反向传播融合,共享中间结果
- 减少一次完整的内存读写
5.3 实际应用中的考虑
Vocabulary Size 的影响:
- 小词表(V < 10000):当前实现已经足够高效
- 大词表(V > 50000):反向传播成为瓶颈,建议:
- 使用 Adaptive Softmax
- 采用 Sampled Softmax 或 Negative Sampling
- 考虑 Hierarchical Softmax
Batch Size 的影响:
- 小 batch:GPU 利用率不足,考虑增加 batch size
- 大 batch:内存带宽成为瓶颈,优化内存访问模式
六、代码实现要点总结
6.1 前向传播关键点
- 输入验证:确保
probs已经归一化,值域在 [0,1] - 数值稳定:虽然代码中未显式处理,但实际应用中应添加 epsilon
- 并行粒度:在
(B, T)维度并行,每个线程处理一个 token - 内存布局:使用行优先存储,保证内存访问的连续性
6.2 反向传播关键点
- 梯度累加:使用
+=而非=,支持梯度累积 - 指示函数:通过条件判断实现 one-hot 编码
- 并行粒度:在
(B, T, V)全维度并行 - 原子操作:当前实现使用
+=,在某些情况下可能需要原子操作
6.3 工程实践建议
-
类型选择:
- 训练:使用
float32保证精度 - 推理:可以考虑
float16或bfloat16加速
- 训练:使用
-
错误处理:
- 检查
targets的有效性(0 <= targets < V) - 处理 NaN 和 Inf 的情况
- 检查
-
性能监控:
- 使用 CUDA Events 精确测量时间
- 监控 GPU 利用率和内存带宽
-
可扩展性:
- 支持不同的 reduction 方式(mean, sum)
- 支持 label smoothing
- 支持 class weights
七、与主流框架的对比
7.1 PyTorch 实现
PyTorch 的 nn.CrossEntropyLoss 内部实现了 Softmax + Cross Entropy 的融合:
# PyTorch 用法
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets) # logits 未经过 softmax
优势:
- 自动融合,数值稳定
- 支持多种 reduction 模式
- 支持 label smoothing 和 class weights
7.2 本文实现的特点
优势:
- 代码简洁,易于理解
- 适合教学和学习 CUDA 编程
- 性能数据透明,便于分析
局限:
- 未实现完整的融合优化
- 缺少数值稳定性保护
- 未支持高级特性(label smoothing 等)
八、进阶优化方向
8.1 算法层面
- Adaptive Softmax:对高频词和低频词使用不同的处理策略
- Sampled Softmax:只在部分负样本上计算损失
- Label Smoothing:平滑 one-hot 标签,提高泛化能力
8.2 系统层面
- 混合精度训练:使用 FP16 加速,关键部分保持 FP32
- 梯度检查点:减少内存占用,允许更大的 batch size
- 分布式训练:在多 GPU 上并行计算
8.3 硬件层面
- Tensor Core 利用:将操作转换为矩阵乘法
- 异步执行:使用 CUDA Streams 重叠计算和通信
- 内存优化:使用 Unified Memory 或 Zero-Copy Memory
九、参考资源
9.1 原文引用
- karpathy/llm.c - crossentropy_forward.cu
- karpathy/llm.c - crossentropy_softmax_backward.cu
- Cross Entropy Loss 的并行化方案
9.2 作者实现
- ifromeast/cuda_learning - crossentropy_forward.cu
- ifromeast/cuda_learning - crossentropy_softmax_backward.cu
十、总结
本文详细分析了 Cross Entropy 损失函数的 CUDA 实现,主要收获包括:
10.1 核心要点
- 原理理解:Cross Entropy 在 one-hot 标签下可以简化为
-log(q_target) - 数值稳定性:与 Softmax 融合可以避免中间结果的精度损失
- 并行策略:前向在 (B,T) 并行,反向在 (B,T,V) 全并行
- 性能特征:前向极快(0.38 ns/token),反向较慢(0.76 µs/token)
- 优化方向:融合算子、内存优化、算法改进
10.2 数值稳定性深度分析
问题根源:
log(0)会产生-∞- 极小的概率值(如 1e-40)在
log后会产生极大的负数 - Softmax 中的
exp运算容易上溢或下溢
解决方案层次:
-
基础保护:添加 epsilon
losses[i] = -logf(probs[ix] + 1e-7f); -
融合优化:使用 log-softmax
// 不计算 softmax 再取 log,而是直接计算 log-softmax log_softmax[i] = logits[i] - max_logit - log(sum(exp(logits - max_logit))) loss = -log_softmax[target] -
数学等价变换:
原始:loss = -log(exp(x_i) / sum(exp(x_j))) 变换:loss = -x_i + log(sum(exp(x_j))) 稳定:loss = -x_i + max(x) + log(sum(exp(x_j - max(x))))
10.3 融合优化的深层价值
内存层面:
- 减少全局内存读写:从 2 次(写 probs + 读 probs)降到 0 次
- 提高缓存利用率:中间结果保持在寄存器或共享内存
- 降低带宽压力:对于大词表(V=50000),节省 200KB/token 的传输
计算层面:
- 避免冗余计算:不需要完整计算所有类别的 softmax
- 简化梯度计算:
probs - targets比链式法则更高效 - 减少数值运算:少一次
exp和log的组合
数值层面:
- 消除中间舍入误差:浮点运算的每次存储都会引入误差
- 保持更高精度:log-softmax 直接计算避免了
exp的动态范围问题
10.4 实践建议
这是一个简洁而高效的实现,适合作为学习 CUDA 编程和深度学习算子优化的入门案例。在实际生产环境中,建议:
- 使用成熟框架:PyTorch、TensorFlow 提供的融合实现更稳定
- 关注数值稳定性:始终使用 log-softmax 而非 softmax + log
- 针对性优化:根据词表大小选择合适的算法(Adaptive/Sampled Softmax)
- 性能监控:持续监控 GPU 利用率和内存带宽,识别瓶颈
10.5 延伸思考
-
为什么反向传播慢 2000 倍?
- 前向:只计算一个元素(target 位置)
- 反向:需要计算所有 V 个元素的梯度
- 计算量比:1 vs V(通常 V=50000)
-
如何进一步优化?
- 使用 warp shuffle 减少共享内存使用
- 采用 block-strided loop 提高指令级并行
- 考虑使用 CUB 库的高效 reduction 原语
-
在 Transformer 中的位置:
- 通常是最后一层的输出
- 只在训练时计算,推理时可以跳过
- 可以与 LayerNorm、Linear 层进一步融合
分析完成时间: 2026-05-05
文档版本: v1.0
分析深度: 详细分析了数值稳定性问题和融合优化技巧