大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

01算子优化的意义

随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习并了解,对于性能优化领域无非是隔靴搔痒,今天也是抽一点时间读了下网上的一些文档和CUDA的文档,整理了学习材料。

首先说下为什么要自定义算子,无非是两个原因,

(1)TF、PyTorch等提供的原生算子不满足需求,需要通过底层接口,比如CUDA层更灵活的实现个性化算子

(2)目前提供的算子性能不足,没有很好的利用到GPU的并行计算优势,有优化空间

接着性能优化的问题说,因为GPU与CPU最大的区别是计算单元占据了绝大部分的体积(图中绿色部分),而控制单元较少。

自定义手写算子可以更好地利用绿色的计算单元的并行化优势。大的思路是Grid包含Block,Block包含Thread。于是首先算子需要把计算逻辑拆分成Thread,让程序可以并行化的运行起来,然后有机的管理各个Block的执行节奏,解决好异步和同步问题,就可以让芯片的计算效率最大化。

Grid of Thread Blocks

02实现流程

整个自定义算子优化其实可以分为CUDA算子定义、算子编译、正方向梯度实现几个步骤。

1、CUDA算子定义

需要定义以下几个文件:

(1)function.cpp:python层和CUDA层的衔接,实现手写的C++ CUDA算子可以被python代码调用

(2)function.h:CUDA函数声明文件

(3)function.cu:CUDA函数的逻辑实现

这里比较核心的文件就是.cu文件,构建的时候主要做两个事:一个是建设Kernel函数,因为只有Kernel函数是在GPU端执行,执行完之后要将控制权给到控制函数,这里要控制好异步、同步的问题。二是在kernel函数中需要通过循环函数定义每个Thread以及每个Block的工作,真正让计算并行化

.cpp文件可以通过pybind函数实现,这个函数主要解决的是C++代码和Python绑定的问题,项目地址:GitHub - pybind/pybind11: Seamless operability between C++11 and Python

2.编译和执行

import torch
from torch.utils.cpp_extension import load
cuda_module = load(name="function",extra_include_paths=["include"],sources=["function.cpp", "function.cu"],verbose=True)

接着就是执行过程中的编译,通过load函数底层会执行JIT(Just in time)的动态编译模式调用.cpp和.cu文件,底层运行的是Ninjia编译器,通过调用nvcc编译.so文件

[1/2] nvcc -c function.cu -o function.cuda.o
[2/3] c++ -c function.cpp -o function.o
[3/3] c++ function.o function.cuda.o -shared -o function.so

3.正反向梯度实现

以上两步实现了自定义算子的逻辑,可以通过手写CUDA算子并在python框架层调用,如果要满足建模需求,还需要实现正方向梯度。具体的做法是在建模的函数中实现forward和backward函数,定义导数作为输出。

以上大概就是手写算子优化的简单流程,仅当学习笔记。

参考:

【1】熬了几个通宵,我写了份CUDA新手入门代码 - 知乎

【2】CUDA C++ Programming Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/283751.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业计算机服务器中了halo勒索病毒如何解密,halo勒索病毒恢复流程

网络技术的不断发展与应用,为企业的生产运营提供了极大便利,越来越多的企业使用数据库存储企业的重要数据,方便工作与生产,但网络是一把双刃剑,网络安全威胁一直存在,并且网络威胁的手段也在不断升级。在本…

【【UART 传输数据实验】】

UART 传输数据实验 通信方式在日常的应用中一般分为串行通信(serial communication)和并行通信(parallel communication)。 我们再来了解下串行通信的特点。串行通信是指数据在一条数据线上,一比特接一比特地按顺序传…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)控件的部分公共属性和事件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)控件的部分公共属性和事件 一、操作环境 操作系统: Windows 10 专业版 IDE:DevEco Studio 3.1 SDK:HarmonyOS 3.1 二、公共属性 常用的公共属性有: 宽(with)、高(height)、…

【STM32】STM32学习笔记-对射式红外传感器计次 旋转编码器计次(12)

00. 目录 文章目录 00. 目录01. NVIC相关函数1.1 NVIC_PriorityGroupConfig函数1.2 NVIC_PriorityGroup类型1.3 NVIC_Init函数1.4 NVIC_InitTypeDef类型 02. 外部中断相关API2.1 GPIO_EXTILineConfig2.2 EXTI_Init2.3 EXTI_GetITStatus2.4 EXTI_ClearITPendingBit2.5 中断回调函…

详细解析POI 和 EasyExcel

在数据量需要被批量导入、导出的时候,就可以使用POI和easyExcel 常用场景: 1、将用户信息导出为excel表格(导出数据…) 2、将Excel表中的信息录入到网站数据库(习题上传…)大大减轻网站录入量!开发中经常会设计到excel的处理&am…

NNDL 作业11 LSTM [HBU ]

目录 习题6-4 推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果 >LSTM前向传播 >反向传播 求梯度 >梯度消失和梯度爆炸怎么来的? >关键点:LSTM如何缓解梯度消失? 习题6-3P 编程实现下图LSTM运行过程 1…

uniapp中uni-data-select下拉框组件如何去除边框?

在目录中找到文件夹。 找到下拉框组件文件夹 注释该文件夹以下代码就能实现下拉框不带边框。

主从reactor多线程实现

现场模型图片,从网上找的 出于学习的目的实现的,如有不对的地方欢迎留言知道,简单实现了http的请求,可通过postman进行访问 启动项目: 返回数据示例 postman请求 附上源码,有问题直接看源码吧

java21特性学习

jdk21下载地址 JDK21文件 JDK21是javaSE平台最新的长期支持版本。 Java SE Java Archive | Oracle JDK21版本说明 JDK 21 Release Notes, Important Changes, and Information JavaSE 版本字符串格式 Version-String Format JavaSE平台采用了基于时间的发布模型,JDK每六个…

晚期食管癌肿瘤治疗线程分类

文章目录 1、肿瘤治疗的线数1.1 基础概念1.2 线程定义1.3 如何计算治疗线数 2 食管癌治疗指南2.1 食管癌诊疗指南2.1 CSCO 本文前半部分主要来源于参考文件1,其余部分来源于官方指南。无原创内容,全部为摘要。 1、肿瘤治疗的线数 1.1 基础概念 抗肿瘤药…

前后端传参:掌握FormData并解决form-data类型传参问题

目录 第一章 解决问题过程 第二章 对form-data的理解 2.1 使用场景 2.2 了解formData对象的创建与使用 2.3 formData常用方法 2.3.1 构造函数 2.3.2 获取数据 2.3.3 添加数据 2.3.4 修改数据 2.3.5 检查是否有该数据 2.3.6 删除数据 2.3.7 遍历formData里的键值对…

Charles 安装与激活

步骤 1:购买 Charles 许可证 访问 Charles 官方网站:https://www.charlesproxy.com/ 在网站上查找并选择 “Buy” 或类似的选项。 选择适合你需求的许可证类型,填写相关信息并完成购买。 如果不想购买可点击此链接Charles 步骤 2&#xff…