NVIDIA's Structured Pruning Tool, Apex Automatic Sparsity [ASP] (2): Code Analysis

The overall structure of the ASP module is as follows:

.
├── COPYRIGHT
├── README.md
├── __init__.py
├── asp.py
├── permutation_lib.py
├── permutation_search_kernels
│   ├── CUDA_kernels
│   │   └── permutation_search_kernels.cu
│   ├── __init__.py
│   ├── call_permutation_search_kernels.py
│   ├── channel_swap.py
│   ├── exhaustive_search.py
│   └── permutation_utilities.py
├── permutation_tests
│   ├── README.md
│   ├── ablation_studies.sh
│   ├── permutation_test.py
│   ├── runtime_table.sh
│   └── unstructured_study.sh
├── sparse_masklib.py
└── test
    ├── checkpointing_test_part1.py
    ├── checkpointing_test_part2.py
    ├── checkpointing_test_reference.py
    ├── test_permutation_application.py
    └── toy_problem.py

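There are three main files:

  • asp.py
  • permutation_lib.py
  • sparse_masklib.py

and three main directories:

  • permutation_search_kernels
  • permutation_tests
  • test

The test directory contains some concrete examples, and permutation_tests is a standalone module for reproducing the experiments in the paper; neither needs attention here. If you do not need the channel permutation algorithm, you can also ignore the permutation_search_kernels directory and permutation_lib.py.

The core of the ASP source is therefore asp.py and sparse_masklib.py. If you do need channel permutation, you can build on these two files to explore the permutation search algorithms and their implementation.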

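The asp.py file

The ASP class

asp.py mainly defines the ASP class, which provides eight class methods (declared with @classmethod):

  • init_model_for_pruning: prepares the model for sparsity
  • init_optimizer_for_pruning: prepares the optimizer for sparsity
  • compute_sparse_masks: computes the sparse masks
  • already_init_asp_model: checks whether the model has already been initialized for sparsity
  • restore_pruned_weights: restores the model's pruned weights
  • is_sparsity_enabled: checks whether sparsity has been enabled on the model
  • prune_trained_model: the entry point that chains the first three methods
  • set_permutation_saving_params: sets parameters for the channel permutation algorithm

The most important are prune_trained_model and the three methods it calls: init_model_for_pruning, init_optimizer_for_pruning, and compute_sparse_masks.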

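Member variables
    __model = None                      # the model to process
    __verbosity = 0                     # how verbose the output is
    __optimizer = None                  # the optimizer to process
    __sparse_parameters = []            # stores information about the sparse parameters
    __calculate_mask = None             # function pointer that generates a mask for a tensor, based on the tensor's shape
    __allow_permutation = True          # whether to enable the channel permutation algorithm
    __all_parameters = []               # stores information about all parameters in the model
    __save_permutation_graph = False    # whether to save the channel permutation graph
    __permutation_output_dir = ''       # output directory for channel permutation information
Member functions
  • prune_trained_model

prune_trained_model is one of the two lines of code that the usage guide asks you to add to the training script, and it is the entry point to the ASP module: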

@classmethod
def prune_trained_model(cls, model, optimizer):
    # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
    cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
    cls.init_optimizer_for_pruning(optimizer)
    cls.compute_sparse_masks()

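prune_trained_model takes two arguments: the trained model and its optimizer.

It calls three methods in turn. init_model_for_pruning and init_optimizer_for_pruning analyze the weights held by the model and the optimizer and do the preparation work (adding mask buffers to the model), and compute_sparse_masks then computes the sparse mask for each weight.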
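The call order can be illustrated with a toy class that keeps all of its state on the class itself, as ASP does. This is only a sketch: ToyASP, its attributes, and the calls list are hypothetical stand-ins, and no real pruning happens.

```python
class ToyASP:
    """Hypothetical stand-in for ASP: all state lives on the class itself."""
    _model = None
    _optimizer = None
    calls = []  # records the order in which the steps run (illustration only)

    @classmethod
    def init_model_for_pruning(cls, model):
        # Same guard as the real method: initialization may happen only once.
        assert cls._model is None, "ASP has been initialized already."
        cls._model = model
        cls.calls.append("init_model_for_pruning")

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        assert cls._model is not None, "init_model_for_pruning must run first."
        cls._optimizer = optimizer
        cls.calls.append("init_optimizer_for_pruning")

    @classmethod
    def compute_sparse_masks(cls):
        cls.calls.append("compute_sparse_masks")

    @classmethod
    def prune_trained_model(cls, model, optimizer):
        # The entry point simply chains the three steps.
        cls.init_model_for_pruning(model)
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks()


ToyASP.prune_trained_model(model="trained-model", optimizer="optimizer")
print(ToyASP.calls)
```

Because the state is stored on the class, calling the entry point a second time in the same process trips the "initialized already" assertion.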

  • init_model_for_pruning
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
         verbosity=3,
         whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention],
         allowed_layer_names=None, disallowed_layer_names=[],
         allow_recompute_mask=False, custom_layer_dict={},
         allow_permutation=True):
    assert (cls.__model is None), "ASP has been initialized already."
    cls.__model = model
    cls.__verbosity = verbosity
    cls.__allow_permutation = allow_permutation

    if isinstance(mask_calculator, str):
        def create_mask_from_pattern(param):
            return create_mask(param, mask_calculator).bool()
        cls.__calculate_mask = create_mask_from_pattern
    else:
        cls.__calculate_mask = mask_calculator #user defined function

    # function to extract variables that will be sparsified.
    # idea is that you will add one of these functions for each module type that can be sparsified.
    if torchvision_imported:
        print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
        torchvision_version = str(torchvision.__version__)
        torchvision_version_major = int(torchvision_version.split('.')[0])
        torchvision_version_minor = int(torchvision_version.split('.')[1])
        if torchvision_version_major == 0 and torchvision_version_minor < 12:
            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
        else:    # Torchvision remove APIs that were deprecated before 0.8 (#5386) in 0.12.0, torchvision.ops.misc.Conv2d is removed
            sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
    else:
        sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
    if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
        sparse_parameter_list.update(custom_layer_dict)
        whitelist += list(custom_layer_dict.keys())

    for module_type in whitelist:
        assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype()

First, the official docstring:

Call this method to modify your model to take advantage of sparse matrix multiplication.

Note that this call alone only augments the model with additional buffers needed for sparse MMA, it does not enable use of sparse MMA.

The docstring points out that init_model_for_pruning only adds the extra mask buffers to the model; it does not actually enable sparse MMA.

Parameters

    model                    The model
    mask_calculator          Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
    verbosity                Integer controlling verbosity level.
                             0 -> Only errors.
                             1 -> Errors and warnings.
                             2 -> Errors, warnings and info.
                             3 -> Errors, warnings, info and debug.
    whitelist                Module types approved for sparsity.
    allowed_layer_names      If not None, only layer names that appear in this list are considered for sparsity.
    disallowed_layer_names   If not [], only layer names that do not appear in this list are considered for sparsity.
    allow_recompute_mask     If True, stores pruned values so that dense weights can be restored.
                             Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
    custom_layer_dict        Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}
    allow_permutation        If True, allow the input channel permutation to ease the influence of weight pruning.

init_model_for_pruning does the following:

  1. It initializes the static ASP class with the arguments passed in, for use in later steps.
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
  2. It sets a function pointer that generates the sparse mask for a given tensor.
if isinstance(mask_calculator, str):
    def create_mask_from_pattern(param):
        return create_mask(param, mask_calculator).bool()
    cls.__calculate_mask = create_mask_from_pattern
else:
    cls.__calculate_mask = mask_calculator #user defined function

# create_mask and the pattern functions below are defined in sparse_masklib.py
""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
    # Reshape tensor and mask.
    shape = tensor.shape
    ttype = tensor.type()
    t = tensor.float().contiguous()

    # 1d-tensor
    if len(shape) == 1:
        t = t.view(1, shape[0])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
    # 2d-tensor (K, C)
    elif len(shape) == 2:
        # linear
        t = t.view(shape[0], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
    # 3d-tensor (K, C, R)
    elif len(shape) == 3:
        # 1d convs
        t = t.permute(0,2,1).contiguous().view(shape[0]*shape[2], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        mask = mask.view(shape[0], shape[2], shape[1]).permute(0,2,1).contiguous()
        return mask.view(shape).type(ttype)
    # 4d-tensor (K, C, R, S)
    elif len(shape) == 4:
        """
        # transformers (bmm)
        t = t.view(shape[0]*shape[1]*shape[2], shape[3])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        return mask.view(shape).type(ttype)
        """
        # 2d convs
        t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
        func = getattr(sys.modules[__name__], pattern, None)
        mask = func(t, density)
        mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()
        return mask.view(shape).type(ttype)

def m4n2_1d(mat, density):
    return mn_1d_best(mat, 4, 2)

def mn_1d_best(matrix, m, n):
    # Find all possible patterns.
    patterns = compute_valid_1d_patterns(m,n).cuda()

    # Find the best m:n pattern (sum of non-masked weights).
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
    mat,shape = reshape_1d(matrix,m)
    pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask
  3. It iterates over the weights of every layer in the model, allocates buffers for specific weights of specific layers, and adds those weights to __sparse_parameters for the later mask computation.
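The pattern search that mn_1d_best performs in step 2 can be sketched in pure Python, without torch or CUDA. The helper names valid_patterns and best_mn_mask are hypothetical. Like the quoted source, the sketch enumerates every pattern that keeps n out of m entries and picks, for each group of m weights, the one that preserves the largest total magnitude.

```python
from itertools import combinations


def valid_patterns(m, n):
    """All length-m 0/1 patterns that keep exactly n entries."""
    return [[1 if i in keep else 0 for i in range(m)]
            for keep in combinations(range(m), n)]


def best_mn_mask(row, m=4, n=2):
    """Mask for one row of weights whose length is a multiple of m."""
    assert len(row) % m == 0
    patterns = valid_patterns(m, n)
    mask = []
    for start in range(0, len(row), m):
        group = [abs(w) for w in row[start:start + m]]
        # Score each pattern by the total magnitude it preserves.
        best = max(patterns, key=lambda p: sum(g * b for g, b in zip(group, p)))
        mask.extend(best)
    return mask


weights = [0.1, -0.9, 0.4, 0.05, 2.0, 0.0, -0.3, 1.5]
mask = best_mn_mask(weights)
pruned = [w * b for w, b in zip(weights, mask)]
print(mask)  # [0, 1, 1, 0, 1, 0, 0, 1]
```

For the 2:4 case there are only six candidate patterns per group, which is why the real implementation can score all of them with a single matrix multiplication followed by argmax.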

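So how does it decide which weights of which layers get a buffer and a mask?

init_model_for_pruning first builds sparse_parameter_list, depending on whether torchvision was imported and on its version. This is a dictionary that records the module types that currently support sparsification and the parameters to sparsify in each: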

torchvision_imported=True
try:
    import torchvision
except ImportError:
    print("[ASP][Warning] torchvision cannot be imported.")
    torchvision_imported=False

if torchvision_imported:
    print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
    torchvision_version = str(torchvision.__version__)
    torchvision_version_major = int(torchvision_version.split('.')[0])
    torchvision_version_minor = int(torchvision_version.split('.')[1])
    if torchvision_version_major == 0 and torchvision_version_minor < 12:
        sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
    else:    # Torchvision remove APIs that were deprecated before 0.8 (#5386) in 0.12.0, torchvision.ops.misc.Conv2d is removed
        sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
else:
    sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
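The only thing the version check changes is whether torchvision.ops.misc.Conv2d is added to the dictionary. A minimal sketch of that decision follows; plain strings stand in for the module classes, and the function names are hypothetical.

```python
def has_legacy_misc_conv2d(version):
    """Mirror the version test in the source: True for torchvision releases before 0.12."""
    major, minor = (int(part) for part in version.split(".")[:2])
    return major == 0 and minor < 12


def module_names_to_sparsify(version):
    """List the supported module types; `version` is None when torchvision is absent."""
    names = ["Linear", "Conv1d", "Conv2d", "Conv3d",
             "NonDynamicallyQuantizableLinear", "MultiheadAttention"]
    if version is not None and has_legacy_misc_conv2d(version):
        names.append("torchvision.ops.misc.Conv2d")
    return names


print(module_names_to_sparsify("0.11.3")[-1])
print(module_names_to_sparsify("0.15.2+cu118")[-1])
```

Only the first two version components are parsed, so a local build suffix such as "+cu118" on the patch component does not affect the result.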

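In addition, init_model_for_pruning uses the custom_layer_dict, whitelist, allowed_layer_names, and disallowed_layer_names arguments to decide exactly which parameters of which modules in the current model are sparsified. It also checks whether each parameter's shape meets the requirements; a parameter that does not is skipped and left dense.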

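Next, init_model_for_pruning creates a buffer named xxx_mma_mask for each qualifying parameter. If allow_recompute_mask=True, it also creates an additional buffer named xxx_mma_pruned_p.

Finally, init_model_for_pruning appends the information for every qualifying parameter to __sparse_parameters.

The permutation search part is left aside for now.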

# Find the modules that need to be sparsified and support sparsification
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
    eligible_modules_list = []
    for name, mod in model.named_modules():
        if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
            if allowed_layer_names is not None and name not in allowed_layer_names:
                continue
            eligible_modules_list.append((name, mod))
    return eligible_modules_list

# Process each module that needs and supports sparsification
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
    add_sparse_attributes(name, sparse_module)

# Process the supported parameters of each module
def add_sparse_attributes(module_name, module):
    sparse_parameters = sparse_parameter_list[type(module)]
    for p_name, p in module.named_parameters():
        if p_name in sparse_parameters and p.requires_grad:
            # check for NVIDIA's TC compatibility: we check along the horizontal direction
            if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math
                continue
            if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
                continue

            mask = torch.ones_like(p).bool()
            buffname = p_name.split(".")[-1] # buffer names cannot contain "."
            module.register_buffer('__%s_mma_mask' % buffname, mask)
            # If the mask may be computed more than once, the pruned values must be saved
            # so they are available when the mask is recomputed.
            # This requires an extra buffer for the original data, named xxx_mma_pruned_p
            if allow_recompute_mask:
                pruned = torch.zeros_like(p).cpu()
                module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
            else:
                pruned = None
            cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
        else:
            continue

if allow_permutation:
    ......
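The name filters and the shape test can be tried out without torch. In the sketch below, layer names map to weight shapes in a plain dictionary, and the function names are hypothetical. The multiples of 8 and 16 are taken from the add_sparse_attributes code quoted above.

```python
def passes_name_filters(name, allowed_layer_names=None, disallowed_layer_names=()):
    """Apply the allow/disallow lists the same way eligible_modules does."""
    if name in disallowed_layer_names:
        return False
    if allowed_layer_names is not None and name not in allowed_layer_names:
        return False
    return True


def passes_shape_check(shape):
    """Keep a weight only if dim 0 is a multiple of 8 and dim 1 a multiple of 16."""
    return shape[0] % 8 == 0 and shape[1] % 16 == 0


# Hypothetical layers: name -> weight shape.
layers = {"fc1": (64, 128), "fc2": (10, 64), "head": (8, 16), "stem": (64, 3, 7, 7)}
selected = [name for name, shape in layers.items()
            if passes_name_filters(name, disallowed_layer_names=["head"])
            and passes_shape_check(shape)]
print(selected)  # ['fc1']
```

Here fc2 is skipped because 10 is not a multiple of 8, stem because its 3 input channels are not a multiple of 16, and head because it is explicitly disallowed.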
  • init_optimizer_for_pruning

Call this method to monkey patch optimizer step function so that masks can be applied to gradients and weights during training.

You must call init_model_for_pruning(…) before calling init_optimizer_for_pruning(…)

The official docstring makes two points:

First, init_optimizer_for_pruning makes the masks take part in the gradient and weight updates during training.

Second, init_model_for_pruning must be called before init_optimizer_for_pruning.

Here is the source code:

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        assert (cls.__optimizer is None), "ASP has initialized optimizer already."
        assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."

        # store pointer to original optimizer step method
        cls.__optimizer = optimizer
        cls.__optimizer.__step = optimizer.step

        def __step(opt_self, *args, **kwargs):
            # prune gradients before step method
            with torch.no_grad():
                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                    if p.grad is not None: #thx pjudd
                        p.grad.mul_(mask)
            # call original optimizer step method
            rval = opt_self.__step(*args, **kwargs)
            # prune parameters after step method
            with torch.no_grad():
                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                    p.mul_(mask)
            return rval
        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)

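init_optimizer_for_pruning works by overriding the step method of the original optimizer, so that gradients are pruned before each step call and weights are pruned after it.

First, ASP points __optimizer at the original optimizer. Because assigning a complex object in Python only creates a new reference to it, both names refer to the same object.

It also gives __optimizer a new attribute named __step that refers to the optimizer's original step method.

Next, init_optimizer_for_pruning defines an inner function __step, which calls the optimizer's original step method and prunes the gradients and parameters in __sparse_parameters before and after that call, respectively.

Finally, the newly defined __step function is bound to __optimizer and assigned to the optimizer's step attribute, which overrides the optimizer's step method.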
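The same rebinding pattern can be reproduced with a toy optimizer and types.MethodType. This is a sketch under simplifying assumptions: ToyOptimizer and patch_step_with_mask are hypothetical, parameters and gradients are plain lists of floats, and the original step is kept in a closure rather than stored as an attribute on the optimizer as ASP does.

```python
import types


class ToyOptimizer:
    """Hypothetical optimizer: plain gradient descent over lists of floats."""

    def __init__(self, params, grads, lr=0.1):
        self.params, self.grads, self.lr = params, grads, lr

    def step(self):
        for i, g in enumerate(self.grads):
            self.params[i] -= self.lr * g


def patch_step_with_mask(optimizer, mask):
    original_step = optimizer.step  # bound method of the unpatched optimizer

    def masked_step(opt_self, *args, **kwargs):
        # Prune gradients before the original step.
        opt_self.grads[:] = [g * m for g, m in zip(opt_self.grads, mask)]
        rval = original_step(*args, **kwargs)
        # Prune parameters after the original step.
        opt_self.params[:] = [p * m for p, m in zip(opt_self.params, mask)]
        return rval

    # Bind the wrapper to this instance so it replaces optimizer.step.
    optimizer.step = types.MethodType(masked_step, optimizer)


opt = ToyOptimizer(params=[1.0, 0.0, 2.0, 0.0], grads=[0.5, 0.5, 0.5, 0.5])
patch_step_with_mask(opt, mask=[1, 0, 1, 0])
opt.step()
print(opt.params)
```

Masked positions receive a zero gradient and are zeroed again after the update, so they stay at zero across training steps, while the training loop keeps calling opt.step() unchanged.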

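To make this easier to follow, here is a diagram of the memory model:

[Figure: memory model of the overridden optimizer step method]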

  • compute_sparse_masks

With the preparation done, this is where sparsity is actually enabled.

For readability, the code that prints informational messages has been removed.

@classmethod
def compute_sparse_masks(cls):
    """Call this method to enable sparsity.
    If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
    """
    with torch.no_grad():
        if cls.__allow_permutation:
            ......

        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
            # mask is initialized to ones_like(p) in init_model_for_pruning
            # If mask.sum() < mask.numel(), mask and p are already sparse: sparsity was enabled earlier
            # and compute_sparse_masks is now being called again
            if mask.sum() < mask.numel(): # when recalculating masks
                # restore dense parameter if allow_recompute_mask is enabled
                # allow_recompute_mask=True : pruned = zeros_like(p)
                # allow_recompute_mask=False: pruned = None
                assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
                p.add_(pruned.cuda())

            mask.set_(cls.__calculate_mask(p))

            if pruned is not None: # stow away pruned weights to cpu
                pruned.set_((p * (~mask)).cpu())

            p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights

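Setting the permutation search aside, compute_sparse_masks first checks mask.sum() < mask.numel() to determine whether a mask has been computed before, and hence whether the model has already been pruned. If it has, the full dense parameters must first be restored using the values saved in pruned. The method then calls the function pointer __calculate_mask, set up in init_model_for_pruning, to compute a sparse mask for each parameter, and multiplies the parameter by its mask to prune it.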
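The restore, recompute, stash, and prune sequence can be traced on a single parameter using plain lists. This is a sketch: top2_of_4_mask and compute_mask_step are hypothetical helpers, and the mask function is a simple stand-in for __calculate_mask.

```python
def top2_of_4_mask(values):
    """Hypothetical mask function: keep the 2 largest magnitudes in each group of 4."""
    mask = []
    for start in range(0, len(values), 4):
        group = values[start:start + 4]
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


def compute_mask_step(p, mask, pruned, calculate_mask):
    """One pass over a single parameter, in the order described above."""
    if sum(mask) < len(mask):  # already sparse: restore the dense values first
        assert pruned is not None, "cannot restore without saved pruned values"
        p = [w + s for w, s in zip(p, pruned)]
    mask = calculate_mask(p)
    if pruned is not None:  # stash the values that are about to be zeroed
        pruned = [w * (1 - m) for w, m in zip(p, mask)]
    p = [w * m for w, m in zip(p, mask)]
    return p, mask, pruned


dense = [0.1, -0.9, 0.4, 0.05]
# First call: the mask is all ones, as after init_model_for_pruning.
p, mask, pruned = compute_mask_step(dense, [1, 1, 1, 1], [0.0] * 4, top2_of_4_mask)
restored = [w + s for w, s in zip(p, pruned)]
# Second call: the mask is now sparse, so the dense values are restored first.
p2, mask2, pruned2 = compute_mask_step(p, mask, pruned, top2_of_4_mask)
print(mask, restored)
```

Adding the stashed values back to the pruned parameter recovers the dense weights, which is why recomputing the mask only works when allow_recompute_mask=True provided the pruned buffer.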
