通用矩阵乘法执行
下面用手工实现的纯粹GEMM和分块GEMM两个例子来解释矩阵分块乘法的原理及其对性能的影响。从下面的测试结果可以看到,两者的性能差距约为46倍(10.7669 / 0.2336)。按照测试所用A10 GPU的FP32峰值算力31 TFLOPS计算,最朴素的算法由于访存效率低,实际浮点算力不足峰值的1%。
# ./naive
AveragePerformance 0.2336 Tflops
# ./block
AveragePerformance 10.7669 Tflops
1. 纯粹GEMM
最简单的矩阵乘法如下:
// Linear index of element (row, col) in a row-major matrix whose rows are
// `stride` elements long.
#define OFFSET(row, col, stride) ((row) * (stride) + (col))

// Naive GEMM: C = A * B, one thread per element of C.
//
// All matrices are row-major: A is M x K, B is K x N, C is M x N.
//
// Launch layout: 2D grid of 2D blocks. The x dimension indexes rows of C
// (bounded by M) and the y dimension indexes columns of C (bounded by N).
// Threads that fall outside the M x N result do nothing, so the grid only has
// to cover the matrix, not match it exactly.
//
// This kernel is the deliberately simple baseline: every operand is read
// straight from global memory, and because threads adjacent in x work on
// adjacent *rows*, their accesses to A and C are K and N elements apart.
__global__ void basic_gemm(
    float * A, float * B, float * C,
    const int M, const int N, const int K) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;

    // Tail guard for grids that overshoot the matrix.
    if (row >= M || col >= N) {
        return;
    }

    // Dot product of row `row` of A with column `col` of B.
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {
        acc += A[OFFSET(row, k, K)] * B[OFFSET(k, col, N)];
    }
    C[OFFSET(row, col, N)] = acc;
}
在A10上测试,其浮点性能大约仅有233 GFLOPS。
// Evaluates a CUDA runtime call inside main(); on failure prints the error
// with its source location and makes main() return 1.
#define GEMM_BENCH_CHECK(call)                                              \
    do {                                                                    \
        cudaError_t err_ = (call);                                          \
        if (err_ != cudaSuccess) {                                          \
            printf("CUDA error %s:%d: %s\n", __FILE__, __LINE__,            \
                   cudaGetErrorString(err_));                               \
            return 1;                                                       \
        }                                                                   \
    } while (0)

// Benchmark driver: launches basic_gemm ITER times on a 4096 x 1024 by
// 1024 x 4096 problem, times the whole batch with CUDA events and prints the
// average throughput in TFLOPS.
//
// The device buffers are allocated but never filled: only the run time is
// measured, the numerical result is not inspected.
int main() {
    const int M = 4096;
    const int K = 1024;
    const int N = 4096;
    const int ITER = 100;
    const int TILE = 32;  // block edge length in threads

    // Integer ceiling division. (ceil(M/32) would truncate first, because
    // M/32 is already an integer division, and lose the tail blocks whenever
    // M or N is not a multiple of 32.)
    // basic_gemm maps grid x to the M dimension and grid y to the N dimension.
    dim3 grid((M + TILE - 1) / TILE, (N + TILE - 1) / TILE, 1);
    dim3 block(TILE, TILE, 1);

    float *d_a, *d_b, *d_c;
    GEMM_BENCH_CHECK(cudaMalloc(&d_a, sizeof(float) * M * K));
    GEMM_BENCH_CHECK(cudaMalloc(&d_b, sizeof(float) * K * N));
    GEMM_BENCH_CHECK(cudaMalloc(&d_c, sizeof(float) * M * N));

    cudaEvent_t start, end;
    GEMM_BENCH_CHECK(cudaEventCreate(&start));
    GEMM_BENCH_CHECK(cudaEventCreate(&end));

    GEMM_BENCH_CHECK(cudaEventRecord(start));
    for (int i = 0; i < ITER; i++) {
        basic_gemm<<<grid, block>>>(d_a, d_b, d_c, M, N, K);
    }
    GEMM_BENCH_CHECK(cudaEventRecord(end));
    GEMM_BENCH_CHECK(cudaEventSynchronize(end));
    // Launch failures are not reported by the <<<>>> syntax itself; query them
    // once, outside the timed region.
    GEMM_BENCH_CHECK(cudaGetLastError());

    float msec;
    GEMM_BENCH_CHECK(cudaEventElapsedTime(&msec, start, end));

    // 2 floating-point operations (multiply + add) per inner-product term.
    // About 3.4e12 in total, so the count needs a 64-bit integer.
    const long long workload = 2LL * M * N * K * ITER;
    double avg_Tflops = ((double)workload / 1e12) / ((double)msec / 1e3);
    printf("AveragePerformance %6.4lf Tflops\n",avg_Tflops);

    GEMM_BENCH_CHECK(cudaEventDestroy(start));
    GEMM_BENCH_CHECK(cudaEventDestroy(end));
    GEMM_BENCH_CHECK(cudaFree(d_a));
    GEMM_BENCH_CHECK(cudaFree(d_b));
    GEMM_BENCH_CHECK(cudaFree(d_c));
    return 0;
}
2. 分块GEMM
为便于理解,我们为相关代码绘制了一张示意图,如图6-37所示。
图6-37 分块通用矩阵乘法示例代码示意图
代码如下:
// Tiled GEMM with a 2D per-thread micro-tile: C = A * B.
//
// All matrices are row-major: A is M x K, B is K x N, C is M x N.
// C is only written, never read, so it does not need to be initialised.
//
// Each thread block produces one BM x BN (128 x 128) tile of C, stepping along
// K in slices of BK (8) that are staged in shared memory. Each thread
// accumulates a TM x TN (8 x 8) patch of that tile.
//
// Launch requirements (the kernel performs no bounds checks):
//   * grid  = (N / 128, M / 128): blockIdx.x selects the tile column,
//             blockIdx.y the tile row;
//   * block = 256 threads, one-dimensional (only threadIdx.x is used);
//   * M and N are multiples of 128 and K is a multiple of 8.
// Static shared memory: (BM * BK + BK * BN) floats = 8 KiB per block.
__global__ void block2d_gemm(const float *A, const float *B, float *C,
                             int M, int N, int K) {
    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const uint cRow = blockIdx.y;
    const uint cCol = blockIdx.x;

    const uint totalResultsBlocktile = BM * BN;
    // Each thread computes TM * TN results of the block tile, so this is the
    // number of threads the block is expected to have (256).
    const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

    // BN / TN threads are needed to span the width of the block tile.
    const int threadCol = threadIdx.x % (BN / TN);
    const int threadRow = threadIdx.x / (BN / TN);

    // Shared-memory staging buffers for the current slices of A and B.
    __shared__ float As[BM * BK];
    __shared__ float Bs[BK * BN];

    // Advance the pointers to the first row of A, the first column of B and
    // the output tile of C that belong to this block.
    A += cRow * BM * K;
    B += cCol * BN;
    C += cRow * BM * N + cCol * BN;

    // Position inside As / Bs that this thread fills on each load step.
    const uint innerRowA = threadIdx.x / BK;
    const uint innerColA = threadIdx.x % BK;
    // Number of rows of As the whole block loads in one step.
    const uint strideA = numThreadsBlocktile / BK;
    const uint innerRowB = threadIdx.x / BN;
    const uint innerColB = threadIdx.x % BN;
    // For both As and Bs each load step spans the full tile width, so that
    // consecutive threads read consecutive global addresses (as opposed to
    // spanning the full height and iterating across columns).
    const uint strideB = numThreadsBlocktile / BN;

    // Per-thread accumulators for the TM x TN results.
    float threadResults[TM * TN] = {0.0f};
    // Per-thread copies of one column fragment of As and one row fragment of Bs.
    float regM[TM] = {0.0f};
    float regN[TN] = {0.0f};

    // Outer loop over the K dimension, one BK-wide slice at a time.
    for (int bkIdx = 0; bkIdx < K; bkIdx += BK) {
        // Fill the shared-memory buffers.
        for (uint loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
            As[(innerRowA + loadOffset) * BK + innerColA] =
                A[(innerRowA + loadOffset) * K + innerColA];
        }
        for (uint loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
            Bs[(innerRowB + loadOffset) * BN + innerColB] =
                B[(innerRowB + loadOffset) * N + innerColB];
        }
        // All loads must be visible before any thread reads the buffers.
        __syncthreads();

        // Advance to the next slice.
        A += BK;      // move BK columns to the right
        B += BK * N;  // move BK rows down

        // Accumulate this slice's contribution to the per-thread results.
        for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
            // Copy the operands needed for this dotIdx into local arrays.
            for (uint i = 0; i < TM; ++i) {
                regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
            }
            for (uint i = 0; i < TN; ++i) {
                regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
            }
            // Outer product of the two fragments.
            for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
                for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
                    threadResults[resIdxM * TN + resIdxN] +=
                        regM[resIdxM] * regN[resIdxN];
                }
            }
        }
        // Nobody may overwrite As / Bs while another thread is still reading.
        __syncthreads();
    }

    // Write the results. The product is stored directly rather than added to
    // the previous contents of C, so the output matches basic_gemm and does
    // not depend on whatever was in C before the launch.
    for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
            C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN] =
                threadResults[resIdxM * TN + resIdxN];
        }
    }
}