张量矩阵乘法分块乘法概述
介绍一下矩阵计算相关的内容, 从最基本的算法,到Cutlass这些线性代数模版库, 特别是Layout代数相关的内容,再逐渐细化到一些硬件实现访存优化和一些算子融合。
6.3.1 GEMM概述
1. GEMM定义
对于一个矩阵乘法, 定义如下:
(6-1)
一个矩阵乘法定义,如图6-26所示。
图6-26 一个矩阵乘法定义图示
2. 内积形式
因此,可以构建一个最简单的算法。
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
for (int k = 0; k < K; ++k)
C[i][j] += A[i][k] * B[k][j];
for (int j = 0; j < N; ++j)
for (int k = 0; k < K; ++k)
C[i][j] += A[i][k] * B[k][j];
这种乘法是也被称为矩阵乘法的内积形式。
(6-2)
可以注意整个过程中随着循环, B矩阵的乘法空间局部性很差,存在多次访问, 因此尽量需要缓存一些数据来避免缓存颠簸(cache thrashing)。
3. 外积形式
换一种思路, 如果按照如下方法构建乘法
其中,
(6-3)
即可以把K维度放在最外面, 这样A和B矩阵,都可以按照列和行整个一块的读取。
for (int k = 0; k < K; ++k) //外环处的dim-k
//C_i的外积
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
C[i][j] += A[i][k] * B[k][j];
//C_i的外积
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
C[i][j] += A[i][k] * B[k][j];