通用矩阵乘法执行

通用矩阵乘法执行
使用两个手工实现的纯粹GEMM和分块GEMM的例子来解释矩阵分块乘法的原理和性能影响, 可以看到性能差距接近53倍. 按照测试的A10 GPU峰值FP32算力31TFFLOPS来算, 最朴素的算法由于访存效率的问题, 浮点算力仅为峰值的1%。
# ./naive 
AveragePerformance  0.2336 Tflops

# ./block 
AveragePerformance  10.7669 Tflops
1. 纯粹GEMM
最简单的矩阵乘法如下:
#define OFFSET(row, col, stride) ((row) * (stride) + (col))

__global__ void basic_gemm(
    float *  A, float *  B, float *  C,
    const int M, const int N, const int K) {

    int _x = blockIdx.x * blockDim.x + threadIdx.x;
    int _y = blockIdx.y * blockDim.y + threadIdx.y;
    if (_x < M && _y < N) {
        float sum = 0.0;
        for (int k = 0; k < K; k++) {
            sum +=A[OFFSET(_x, k, K)] *  B[OFFSET(k , _y, N)];
        }
        C[OFFSET(_x, _y, N)] = sum;
    }
}
在A10上测试其FLOS大概仅有233GFlops。
int main() {
    const int M = 4096;
    const int K = 1024;
    const int N = 4096;
    const int ITER = 100;

    dim3 gridDim(ceil(M/32), ceil(N/32), 1);
    dim3 blockDim(32, 32, 1);

    float *d_a, *d_b, *d_c ;
    cudaMalloc(&d_a, M * K * sizeof(float));
    cudaMalloc(&d_b, K * N * sizeof(float));
    cudaMalloc(&d_c, M * N * sizeof(float));

    cudaEvent_t start, end;
    cudaEventCreate(&start);
    cudaEventCreate(&end);

    cudaEventRecord(start);
    for (int i = 0; i < ITER; i++)
        basic_gemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);
    cudaEventRecord(end);
    cudaEventSynchronize(end);

    float msec;
    cudaEventElapsedTime(&msec, start, end);
    
    long workload =  long(M) * N * K * 2 * ITER;
    double avg_Tflops = ((double)workload / 1e12 ) / (double(msec)/ 1e3);
    printf("AveragePerformance  %6.4lf Tflops\n",avg_Tflops);

    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);
}
2. 分块GEMM
相关的代码画了一个容易理解的示意图,如图6-37所示。
图6-37 分块通用矩阵乘法示例代码示意图
代码如下:
__global__ void block2d_gemm(const float *A, const float *B, float *C,
                            int M, int N, int K) {

  const int BM = 128;
  const int BN = 128;
  const int BK = 8;
  const int TM = 8;
  const int TN = 8;

  const uint cRow = blockIdx.y;
  const uint cCol = blockIdx.x;

  const uint totalResultsBlocktile = BM * BN;
//线程负责计算块中的TM*TN元素
    const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);

//BN/TN是跨越一列的线程数
  const int threadCol = threadIdx.x % (BN / TN);
  const int threadRow = threadIdx.x / (BN / TN);
 //在smem中为当前块分配空间
   __shared__ float As[BM * BK];
  __shared__ float Bs[BK * BN];
 //将块图块移动到A行和B列的开头
   A += cRow * BM * K;
  B += cCol * BN;
  C += cRow * BM * N + cCol * BN;
  //计算此线程将加载到SMEM中的索引
   const uint innerRowA = threadIdx.x / BK;
  const uint innerColA = threadIdx.x % BK;
   //计算单个块在单个步骤中加载的As行数
    const uint strideA = numThreadsBlocktile / BK;
  const uint innerRowB = threadIdx.x / BN;
  const uint innerColB = threadIdx.x % BN;
  //As和Bs,希望每个负载都跨越整个列宽,以便更好地进行GMEM合并
//(与跨越整个行宽和跨列迭代相反)
const uint strideB = numThreadsBlocktile / BN;
//为注册表文件中的结果分配线程本地缓存
  float threadResults[TM * TN] = {0.0};
//为As和Bs注册缓存
  float regM[TM] = {0.0};
  float regN[TN] = {0.0};
  //最外层的分块平铺环
  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
//填充SMEM缓存
    for (uint loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
      As[(innerRowA + loadOffset) * BK + innerColA] =
          A[(innerRowA + loadOffset) * K + innerColA];
    }
    for (uint loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
      Bs[(innerRowB + loadOffset) * BN + innerColB] =
          B[(innerRowB + loadOffset) * N + innerColB];
    }
    __syncthreads();
  //超前分块
    A += BK;     //将BK列向右移动
    B += BK * N; //向下移动BK行

    // 计算每个线程的结果
    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// 进入内存的块
      for (uint i = 0; i < TM; ++i) {
        regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
      }
      for (uint i = 0; i < TN; ++i) {
        regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
      }
      for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
        for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
          threadResults[resIdxM * TN + resIdxN] +=
              regM[resIdxM] * regN[resIdxN];
        }
      }
    }
    __syncthreads();
  }

  // 写出结果
  for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
    for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
      C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN] =
         threadResults[resIdxM * TN + resIdxN] +
         C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN];
    }
  }
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/800209.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

交易柜台系统技术名词

目录交互示意图柜台API前置机行情和交易接口生产环境服务器托管(Co-location)什么是高频交易 (HFT)?交互示意图 程序化交易用户是如何与期货公司、交易所进行信息交互的?柜台 依据国内监管要求,客户无法直连交易所系统,中间必须经过期货公司(Broker)的系统,这便是柜台系…

全网最适合入门的面向对象编程教程:50 Python函数方法与接口-接口和抽象基类

在Python中,接口和抽象基类(Abstract Base Classes, ABCs)都用于定义类的结构和强制子类实现特定的方法,Python 没有内建的接口机制,但可以通过抽象基类(ABC)来模拟接口的行为。全网最适合入门的面向对象编程教程:50 Python 函数方法与接口-接口和抽象基类摘要: 在 Py…

javafx jlink 遇到的非模块化的依赖打包报错“模块异常”的问题和处理

javafx jlink 遇到的问题和处理 简介 javafx:jlink 是 javafx-maven-plugin 插件中的一个目标,用于创建一个自包含的 JavaFX 应用程序运行时映像。这个目标利用 Java 的 jlink 工具来生成一个包含应用程序及其所有依赖的定制化运行时映像,从而简化部署和分发。创建自包含运行…

The minimum required version for Powerlevel10k is 5.1

目录一、背景二、原因三、解决1、安装 ZSH 最新版本2、效果3、下载了还是显示 ZSH 版本为 5.0.2 怎么办 一、背景 安装 ZSH 主题 Powerlevel10k 时报错:You are using ZSH version 5.0.2. The minimum required version for Powerlevel10k is 5.1. Type echo $ZSH_VERSION to …

Python pycryptodome类库使用学习总结

AES数据加解密 以下代码生成一个新的AES-128密钥,并将一段数据加密到一个文件中。我们使用 CTR 模式(这是一种 经典操作模式, 简单但不再推荐)。 仅使用CTR,接收者无法检测到密文(即加密数据)在传输过程中是否被修改。为了应对这种风险,例中还附加了一个MAC身份验证标签…

电脑设置系统不自动更新

1、win + R 2、计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\StateVariables 3、右边空白处右击 -> 新建 -> DWORD值,命名为FlightSettingsMaxPauseDays,点击基数选择十进制,数值设置为9999(表示不更新的天数)

同花顺--涨停板改变颜色

复制以下代码 IF(C>=REF(C,1)*1.095 AND C=H) RETURN "涨停"; 然后进行操作: 1、打开同花顺软件,右击K线,单击修改K线2、光标挪到代码首行行首,回车换行3、粘贴一下4、点击设置标志5、命名为涨停,选颜色,填充打勾6、点击确定

关于零值和nil

1. 零值 零值是指当你声明变量(分配内存)并未显式初始化时,始终为你的变量自动设置一个默认初始值的策略。 对于值类型:布尔类型为 false, 数值类型为 0,字符串为 "",数组和结构会递归初始化其元素或字段,即其初始值取决于元素或字段。 对于引用类型: 均为 n…

利用AutoGpt将任何模型支持o1模型的推理实现

利用AutoGpt将任何模型支持o1模型的推理实现 相信大家都对于OpenAI最新出的o1模型都非常关注,它已经能通过推理让回复的效果更加理想, 但是目前o1的限制太大,而且使用o1至少也是需要购买OpenAI官方的会员价格也在20美刀(好贵!!),于是乎社区出现非常多相似的实现,通过更…

C语言类型与强制类型转换

目录类型关键字sizeof如何理解强制类型转化不同类型的0null字符设备(补充) char有有符号和无符号两种类型,字符是无符号类型.(补充) getchar的返回值为什么是int键盘输入的内容,以及往显示器中打印的内容,都是字符 --> 键盘/显示器称为字符设备 类型C语言为何有类型? 让我们…

如何在 ASP.NET Core Web API 方法执行前后 “偷偷“ 作一些 “坏“ 事?初识 ActionFilterAttribute

ActionFilterAttribute 是一种作用于控制器 Action 方法的特性(Attribute),通过它,你可以在操作执行前后、异常处理时等不同的阶段插入自定义逻辑。 比如在执行操作方法之前修改请求参数、记录日志、进行权限验证等操作,在执行操作方法之后发送邮件、同步数据等等。 本文主…

访问Github卡顿甚至进不去的解决办法(适用于Windows)

本文使用Watt Tookit(原Steam++)解决了Github在国内访问速度卡顿甚至无反应的问题,通过NDM和镜像网站实现Github大文件高速下载。本文首发自个人博客:点我查看 一、前言 Github 是全球知名的开源宝库,但是对国内用户并不友好。当我们在浏览器中输入www.github.com时,如果…