【yolov8】yolov8剪枝训练流程

yolov8剪枝训练流程

流程:

  • 约束
  • 剪枝
  • 微调

一、正常训练

yolo train model=./weights/yolov8s.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=train

二、约束训练

2.1 修改YOLOv8代码:

ultralytics/yolo/engine/trainer.py
添加内容:

# Backwardself.scaler.scale(self.loss).backward()# ========== 新增 ==========l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# ========== 新增 ==========# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni

2.2 训练

需要注意的就是amp=False

yolo train model=prunt/train/weights/best.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=constraint

训练完会得到一个best.pt和last.pt,推荐用last.pt

三、剪枝

上一步得到的last.pt作为剪枝对象,运行项目中的prun.py文件:

*这里的剪枝代码仅适用yolov8原模型,如有模块/模型的更改,则需要修改剪枝代码*

运行完会得到prune.pt和prune.onnx可以在netron.app网站拖入onnx文件查看是否剪枝成功了,成功的话可以看到某些通道数字为单数或者一些不规律的数字,如下图:

在这里插入图片描述

左侧为未剪枝的模型,右侧为剪枝后的模型。

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

四、 回调训练(finetune)

回调训练的唯一关键点就在于不让模型从yaml文件加载结构,直接加载pt文件

两种方法(因yolov8版本不同而选择不同方法):

方法一:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的443行左右

self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)# ========== 新增该行代码 ==========self.model = weights# ========== 新增该行代码 ==========return ckpt

方法二:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的335行左右

if not args.get('resume'):  # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model######################上面两行注释掉,添加下面一行#####self.trainer.model = self.model.train()##########################修改####################self.trainer.hub_session = self.session  # attach optional HUB session
3.3 修改完代码就可以进行finetun训练了

命令行输入:

yolo train model=prun/prune/weights/last_prune.pt data="yolo_bvn.yaml" amp=False epochs=100 project=prun name=finetune device=0

五、结果展示:

5.1模型大小:ONNX模型大小从42M减少到34M

在这里插入图片描述

5.2PR曲线:

正常训练约束训练100轮微调
在这里插入图片描述在这里插入图片描述在这里插入图片描述

5.3实测视频在ubuntu上检测速度:

未剪枝:平均每帧5毫秒

剪枝后:平均每帧3.7毫秒

六、问题及解决:

对剪枝完的yolov8进行finetune时遇到RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

self.proj 可能不在与 pred_dist 相同的设备上。这可能是因为 self.proj 被指定在 CPU 上,而 pred_dist 在 GPU 上(或反之)。
要解决这个问题,需要确保两个张量位于相同的设备上。可以使用 to() 方法将 self.proj 放到与 pred_dist 相同的设备上。

解决:在loss.py添加如下代码:

def bbox_decode(self, anchor_points, pred_dist):"""Decode predicted object bounding box coordinates from anchor points and distribution."""if self.use_dfl:b, a, c = pred_dist.shape  # batch, anchors, channels####添加device = pred_dist.deviceself.proj = self.proj.to(device)#####pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)

七、参考:

7.1 【yolov8系列】 yolov8 目标检测的模型剪枝_yolov8 剪枝-CSDN博客
7.2 YOLOv8剪枝全过程-CSDN博客

7.3 剪枝与重参第七课:YOLOv8剪枝-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/661550.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LocalAI】(10):在autodl上编译embeddings.cpp项目,转换bge-base-zh-v1.5模型成ggml格式,本地运行main成功

1,关于 localai LocalAI 是一个用于本地推理的,与 OpenAI API 规范兼容的 REST API。 它允许您在本地使用消费级硬件运行 LLM(不仅如此),支持与 ggml 格式兼容的多个模型系列。支持CPU硬件/GPU硬件。 【LocalAI】&…

Mac数据恢复软件快速比较:适用于Macbook的10佳恢复软件

数据丢失导致无数个人和组织每天损失大量资金。更糟糕的是,某些文件具有货币价值和情感意义,使它们不可替代,并使数据恢复成为唯一可行的选择。最好的消息是Mac用户可以从各种数据恢复程序中进行选择。为了帮助您尽可能快速、轻松地恢复丢失的…

vue3步骤条带边框点击切换高亮

如果是div使用clip-path: polygon(0% 0%, 92% 0%, 100% 50%, 92% 100%, 0% 100%, 8% 50%);进行裁剪加边框没实现成功。目前这个使用svg完成带边框的。 形状可自行更改path 标签里的 :d“[num ! 1 ? ‘M 0 0 L 160 0 L 176 18 L 160 38 L 0 38 L 15.5 18 Z’ : ‘M 0,0 L 160,0…

安卓中对象序列化面试问题及回答

1. 什么是对象的序列化? 答: 序列化是将对象转换为字节流的过程,以便将其存储在文件、数据库或通过网络传输。反序列化则是将字节流重新转换为对象的过程。 2. 为什么在 Android 开发中需要对象的序列化? 答: 在 An…

锂电池SOH预测 | 基于BP神经网络的锂电池SOH预测(附matlab完整源码)

锂电池SOH预测 锂电池SOH预测完整代码锂电池SOH预测 锂电池的SOH(状态健康度)预测是一项重要的任务,它可以帮助确定电池的健康状况和剩余寿命,从而优化电池的使用和维护策略。 SOH预测可以通过多种方法实现,其中一些常用的方法包括: 容量衰减法:通过监测电池的容量衰减…

css利用transform:skew()属性画一个大屏的背景斜面四边形特效

在工作工程中需要写一个如下的大屏背景&#xff0c;是由几个斜面做成的效果 使用css transform function中的skew()方法实现画其中一个斜面&#xff0c;然后调整背景色实现 写一个div <div class"skew_container test-2"><div class"skew_container_it…

2024五一杯数学建模C题思路分享 - 煤矿深部开采冲击地压危险预测

文章目录 1 赛题选题分析 2 解题思路2.1 问题重述2.2 第一问完整思路2.2 二、三问思路更新 3 最新思路更新 1 赛题 C题 煤矿深部开采冲击地压危险预测 煤炭是中国的主要能源和重要的工业原料。然而&#xff0c;随着开采深度的增加&#xff0c;地应力增大&#xff0c;井下煤岩动…

HarmonyOS 4.0(鸿蒙开发)01 - 怎么学习鸿蒙引导篇

作为公司的全栈开发工程师 以及 未来的发展是有鸿蒙这个阶段的&#xff0c;以及本身具有这个技术栈由此后续会分享自己在实战中学习到的东西&#xff0c;碰到的bug都会分享出来&#xff0c;这是引导篇期待后续的更新 学习目标&#xff1a; 理解HarmonyOS操作系统的架构和开发…

探索潜力:中心化交易所平台币的对比分析

核心观点 平台币在过去一年里表现差异显著&#xff1a; 在过去的一年里&#xff0c;只有少数几个平台币如BMX、BGB和MX的涨幅超过了100%。相比之下&#xff0c;由于市值较高&#xff0c;BNB和OKB的涨幅相对较低。 回购和销毁机制在平台币价值中起决定性作用&#xff1a; 像M…

云备份项目->配置环境

升级gcc到7.3版本 sudo yum install centos-release-scl-rh centos-release-scl sudo yum install devtoolset-7-gcc devtoolset-7-gcc-c source /opt/rh/devtoolset-7/enable echo "source /opt/rh/devtoolset-7/enable" >> ~/.bashrc 安装Jsoncpp库 sud…

Ubuntu如何更换 PyTorch 版本

环境&#xff1a; Ubuntu22.04 WLS2 问题描述&#xff1a; Ubuntu如何更换 PyTorch 版本考虑安装一个为 CUDA 11.5 编译的 PyTorch 版本。如何安装旧版本 解决方案&#xff1a; 决定不升级CUDA版本&#xff0c;而是使用一个与CUDA 11.5兼容的PyTorch版本&#xff0c;您可…

基于ssm+vue+Mysql的药源购物网站

开发语言&#xff1a;Java框架&#xff1a;ssmJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;Maven3.…