CUDA3_通用矩阵乘法_从入门到熟练

CUDA(三):通用矩阵乘法:从入门到熟练

作者: 紫气东来
发布时间: 2024-04-26
原文链接: https://zhuanlan.zhihu.com/p/657632577


文章概述

本文深入探讨了通用矩阵乘法(GEMM)在CUDA中的实现与优化,从最基础的朴素实现逐步优化到接近cuBLAS库的性能水平。GEMM是深度学习和高性能计算中的核心操作,也是评估GPU计算能力的标准基准。

一、GEMM的重要性和应用场景

1.1 GEMM的定义

GEMM的数学定义为:C ← αAB + βC

其中:

1.2 计算复杂度分析

1.3 应用场景

二、朴素实现及其问题

2.1 CPU参考实现

#define OFFSET(row, col, ld) ((row) * (ld) + (col))

void cpuSgemm(float *a, float *b, float *c, const int M, const int N, const int K) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float psum = 0.0;
            for (int k = 0; k < K; k++) {
                psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
            }
            c[OFFSET(m, n, N)] = psum;
        }
    }
}

2.2 朴素CUDA实现

__global__ void naiveSgemm(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {
    
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    if (m < M && n < N) {
        float psum = 0.0;
        #pragma unroll
        for (int k = 0; k < K; k++) {
            psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
        }
        c[OFFSET(m, n, N)] = psum;
    }
}

2.3 朴素实现的性能问题

测试环境: Tesla V100-PCIE-32GB (FP32峰值算力15.7 TFLOPS)

性能数据:

问题分析:

  1. 计算访存比极低: 每次乘累加需要两次Global Memory读取
  2. 内存访问效率低: 大量重复读取相同数据
  3. 计算访存比: 2KMN/(KMN/32 × 5 × 4) = 3.2 OP/byte
  4. 理论性能上限: 仅约2.4 TFLOPS(受带宽限制)

2.4 执行流程分析(以M=N=K=512为例)

  1. 每个thread负责计算C矩阵中的1个元素
  2. 需要从A读取一行(K个元素),从B读取一列(K个元素)
  3. 执行K次乘累加运算
  4. 同一warp内的32个thread可以部分合并访存
  5. 每个warp每次循环需要5次transaction(1次读A,4次读B)

三、优化方法详解

3.1 优化一:使用Shared Memory进行分块(Tiling)

核心思想: 将矩阵分块,利用Shared Memory减少Global Memory访问

分块策略:

参数选择:

计算访存比分析:

核心代码:

__shared__ float s_a[BM][BK];
__shared__ float s_b[BK][BN];
float r_c[TM][TN] = {0.0};

for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
    // 加载数据到Shared Memory
    FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
    FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);
    __syncthreads();
    
    // 计算
    for (int k = 0; k < BK; k++) {
        for (int m = 0; m < TM; m++) {
            for (int n = 0; n < TN; n++) {
                r_c[m][n] += s_a[ty*TM+m][k] * s_b[k][tx*TN+n];
            }
        }
    }
    __syncthreads();
}

性能提升:

3.2 优化二:解决Bank Conflict

Bank Conflict问题:

V1版本的Bank Conflict:

  1. 矩阵A按行存储,读取列向量时产生冲突
  2. 每个线程读取连续8个数,需要2条LDS.128指令
  3. 同一warp不同线程的访存地址间隔,产生冲突

优化策略:

  1. 矩阵A转置存储: 改为[BK][BM]布局,按列存储
  2. 重新划分计算块: 将TM×TN分为两个TM/2×TN块
  3. 使用FLOAT4指令: 一次读取4个连续元素

优化后代码:

__shared__ float s_a[BK][BM];  // 转置存储
__shared__ float s_b[BK][BN];

float r_comp_a[TM];
float r_comp_b[TN];

for (int tk = 0; tk < BK; tk++) {
    FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2]);
    FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
    FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2]);
    FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);
    
    for (int tm = 0; tm < TM; tm++) {
        for (int tn = 0; tn < TN; tn++) {
            r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
        }
    }
}

性能提升:

3.3 优化三:Double Buffering(双缓冲)

核心思想: 通过流水线化隐藏访存延迟,使计算和访存并行

Single Buffering问题:

Double Buffering原理:

实现要点:

  1. Shared Memory大小翻倍: s_a[2][BK][BM], s_b[2][BK][BN]
  2. 主循环从bk=1开始,第一次加载在循环前
  3. 最后一次计算在循环后
  4. 每次循环只需一次__syncthreads()
  5. 先加载下一次数据到寄存器,再计算,最后写入Shared Memory

核心代码结构:

// 第一次加载
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[0][...] = r_load_a[...];
s_b[0][...] = r_load_b[...];

for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
    int smem_sel = (bk - 1) & 1;
    int smem_sel_next = bk & 1;
    
    // 预取下一次数据到寄存器
    FLOAT4(r_load_a[0]) = FLOAT4(a[...]);
    FLOAT4(r_load_b[0]) = FLOAT4(b[...]);
    
    // 使用当前buffer计算
    for (int tk = 0; tk < BK; tk++) {
        // 计算逻辑
    }
    
    // 将预取数据写入下一个buffer
    s_a[smem_sel_next][...] = r_load_a[...];
    s_b[smem_sel_next][...] = r_load_b[...];
    
    __syncthreads();
}

// 最后一次计算
for (int tk = 0; tk < BK; tk++) {
    // 计算逻辑
}

性能提升:

四、cuBLAS性能对比

4.1 cuBLAS简介

cuBLAS是NVIDIA官方的BLAS库实现,支持:

4.2 cuBLAS的层次化设计

cuBLAS/CUTLASS采用三层tile结构:

  1. Thread Block Tile: Global Memory → Shared Memory
  2. Warp Tile: Shared Memory → Register
  3. Thread Tile: Register → CUDA Core计算

4.3 cublasSgemm使用

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);
float cublas_alpha = 1.0;
float cublas_beta = 0;
cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, 
            N, M, K, &cublas_alpha, d_b, N, d_a, K, 
            &cublas_beta, d_c, N);

4.4 性能对比

cuBLAS性能:

各版本性能对比表:

优化方法 M=N=K=512 (GFLOPS) M=N=K=16384 (GFLOPS) 峰值利用率
Naive 1,323 1,792 11.5%
V1: Shared Memory 1,562 8,119 51.7%
V2: Bank Conflict 1,670 11,670 74.3%
V3: Double Buffer 2,063 12,658 80.6%
cuBLAS 4,864 12,932 82.4%

五、关键技术点总结

5.1 内存层次优化

Global Memory → Shared Memory:

Shared Memory → Register:

5.2 并行策略

线程映射:

访存合并:

5.3 流水线优化

Double Buffering:

5.4 资源限制

Shared Memory:

寄存器:

线程数:

六、性能分析与调优经验

6.1 性能瓶颈分析

朴素实现的瓶颈:

Shared Memory优化后:

进一步优化方向:

6.2 参数调优经验

BM, BN选择:

BK选择:

TM, TN选择:

6.3 Profiling工具使用

nsys (Nsight Systems):

ncu (Nsight Compute):

cuda-calculator:

七、代码示例与实现细节

7.1 索引计算

Global Memory索引:

// Block级别
int load_a_gmem_m = by * BM + load_a_smem_m;  // 全局行号
int load_a_gmem_k = bk * BK + load_a_smem_k;  // 全局列号

// Thread级别
int load_a_smem_m = tid >> 1;           // tid/2, Shared Memory行号
int load_a_smem_k = (tid & 1) << 2;     // (tid%2)*4, Shared Memory列号

Shared Memory索引:

// 计算时的索引
int comp_a_smem_m = ty * TM + m;        // 线程块内行号
int comp_b_smem_n = tx * TN + n;        // 线程块内列号

结果写回索引:

int store_c_gmem_m = by * BM + ty * TM + i;      // 全局行号
int store_c_gmem_n = bx * BN + tx * TN + j;      // 全局列号

7.2 向量化访存

FLOAT4宏定义:

#define FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])

使用示例:

// 一次读取4个float
FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);

// 一次写入4个float
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);

7.3 同步机制

Block内同步:

__syncthreads();  // 等待Block内所有线程

使用场景:

八、进阶优化方向

8.1 Warp级优化

Warp Shuffle指令:

Warp级同步:

__syncwarp();  // warp内同步,比__syncthreads()更轻量

8.2 Tensor Core加速

特点:

使用方式:

8.3 更多优化技术

Prefetching:

Loop Unrolling:

Async Copy:

九、总结与展望

9.1 优化路径总结

  1. 朴素实现 (11.5%): 功能正确但性能差
  2. Shared Memory (51.7%): 提高计算访存比
  3. Bank Conflict优化 (74.3%): 提高Shared Memory效率
  4. Double Buffering (80.6%): 隐藏访存延迟
  5. cuBLAS (82.4%): 工业级优化

9.2 关键收获

性能优化思路:

内存层次利用:

并行化策略:

9.3 实际应用建议

生产环境:

学习研究:

9.4 扩展阅读

相关主题:

参考资源:


附录:完整性能数据

V1版本(Shared Memory)

M N K =    512    512   1024, Performance =  1562.1563 Gflops
M N K =   1024   1024   1024, Performance =  5748.3952 Gflops
M N K =   4096   4096   1024, Performance =  7175.4698 Gflops
M N K =  16384  16384   1024, Performance =  8118.7715 Gflops

V2版本(Bank Conflict优化)

M N K =    512    512   1024, Performance =  1669.7479 Gflops
M N K =   1024   1024   1024, Performance =  6155.0281 Gflops
M N K =   4096   4096   1024, Performance = 10029.7885 Gflops
M N K =  16384  16384   1024, Performance = 11669.5995 Gflops

V3版本(Double Buffering)

M N K =    512    512   1024, Performance =  2062.9786 Gflops
M N K =   1024   1024   1024, Performance =  8119.7302 Gflops
M N K =   4096   4096   1024, Performance = 10320.6843 Gflops
M N K =  16384  16384   1024, Performance = 12658.1556 Gflops

cuBLAS

M N K =    512    512   1024, Performance =  4863.9646 Gflops
M N K =   1024   1024   1024, Performance =  8176.5614 Gflops
M N K =   4096   4096   1024, Performance = 10994.0128 Gflops
M N K =  16384  16384   1024, Performance = 12931.6086 Gflops

文章来源: 知乎 - 紫气东来
代码仓库: https://github.com/ifromeast/cuda_learning
测试硬件: Tesla V100-PCIE-32GB (FP32峰值15.7 TFLOPS)