yolov7改进优化之蒸馏(一)

最近比较忙,有一段时间没更新了,最近yolov7用的比较多,总结一下。上一篇yolov5及yolov7实战之剪枝_CodingInCV的博客-CSDN博客 我们讲了通过剪枝来裁剪我们的模型,达到在精度损失不大的情况下,提高模型速度的目的。上一篇是从速度的角度,这一篇我们从检测性能的角度来改进yolov7(yolov5也类似)。
对于提高检测器的性能,我们除了可以从增加数据、修改模型结构、修改loss等模型本身的角度出发外,深度学习领域还有一个方式—蒸馏。简单的说,蒸馏就是让性能更强的模型(teacher, 参数量更大)来指导性能更弱student模型,从而提高student模型的性能。
蒸馏的方式有很多种,比较简单暴力的比如直接让student模型来拟合teacher模型的输出特征图,当然蒸馏也不是万能的,毕竟student模型和teacher模型的参数量有差距,student模型不一定能很好的学习teacher的知识,对于自己的任务有没有作用也需要尝试。
本篇选择的方法是去年CVPR上的针对目标检测的蒸馏算法:
yzd-v/FGD: Focal and Global Knowledge Distillation for Detectors (CVPR 2022) (github.com)
针对该方法的解读可以参考:FGD-CVPR2022:针对目标检测的焦点和全局蒸馏 - 知乎 (zhihu.com)
本篇暂时不涉及理论,重点在把这个方法集成到yolov7训练。步骤如下。

载入teacher模型

蒸馏首先需要有一个teacher模型,这个teacher模型一般和student同样结构,只是参数量更大、层数更多。比如对于yolov5,可以尝试用yolov5m来蒸馏yolov5s。
train.py增加一个命令行参数:

    parser.add_argument("--teacher-weights", type=str, default="", help="initial weights path")

在train函数中载入teacher weights,过程与原有的载入过程类似,注意,DP或者DDP模型也要对teacher模型做对应的处理。

# teacher modelif opt.teacher_weights:teacher_weights = opt.teacher_weights# with torch_distributed_zero_first(rank):#     teacher_weights = attempt_download(teacher_weights)  # download if not found locallyteacher_model = Model(teacher_weights, ch=3, nc=nc).to(device)  # create    # load state_dictckpt = torch.load(teacher_weights, map_location=device)  # load checkpointstate_dict = ckpt["model"].float().state_dict()  # to FP32teacher_model.load_state_dict(state_dict, strict=True)  # load#set to evalteacher_model.eval()#set IDetect to train mode# teacher_model.model[-1].train()logger.info(f"Load teacher model from {teacher_weights}")  # report# DP modeif cuda and rank == -1 and torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)if opt.teacher_weights:teacher_model = torch.nn.DataParallel(teacher_model)# SyncBatchNormif opt.sync_bn and cuda and rank != -1:model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)logger.info("Using SyncBatchNorm()")if opt.teacher_weights:teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model).to(device)

teacher模型不进行梯度计算,因此:

if opt.teacher_weights:for param in teacher_model.parameters():param.requires_grad = False

蒸馏Loss

蒸馏loss是计算teacher模型的一层或者多层与student的对应层的相似度,监督student模型向teacher模型靠近。对于yolov7,可以去监督三个特征层。
参考FGD的开源代码,我们在loss.py中增加一个FeatureLoss类, 参数暂时使用默认:

class FeatureLoss(nn.Module):"""PyTorch version of `Feature Distillation for General Detectors`Args:student_channels(int): Number of channels in the student's feature map.teacher_channels(int): Number of channels in the teacher's feature map. temp (float, optional): Temperature coefficient. Defaults to 0.5.name (str): the loss name of the layeralpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.0005lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005"""def __init__(self,student_channels,teacher_channels,temp=0.5,alpha_fgd=0.001,beta_fgd=0.0005,gamma_fgd=0.001,lambda_fgd=0.000005,):super(FeatureLoss, self).__init__()self.temp = tempself.alpha_fgd = alpha_fgdself.beta_fgd = beta_fgdself.gamma_fgd = gamma_fgdself.lambda_fgd = lambda_fgdif student_channels != teacher_channels:self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)else:self.align = Noneself.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)self.channel_add_conv_s = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))self.channel_add_conv_t = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),nn.LayerNorm([teacher_channels//2, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))self.reset_parameters()def forward(self,preds_S,preds_T,gt_bboxes,img_metas):"""Forward function.Args:preds_S(Tensor): Bs*C*H*W, student's feature mappreds_T(Tensor): Bs*C*H*W, teacher's feature mapgt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)img_metas (list[dict]): Meta information of each image, e.g.,image size, scaling factor, etc."""assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dim of teacher and student differ'device = gt_bboxes.deviceself.to(device)if self.align is not None:preds_S = self.align(preds_S)N,C,H,W = preds_S.shapeS_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)Mask_fg = torch.zeros_like(S_attention_t)# Mask_bg = torch.ones_like(S_attention_t)wmin,wmax,hmin,hmax = [],[],[],[]img_h, img_w = img_metasbboxes = gt_bboxes[:,2:6]#xywh2xyxybboxes = xywh2xyxy(bboxes)new_boxxes = torch.ones_like(bboxes)new_boxxes[:, 0] = torch.floor(bboxes[:, 0]*W)new_boxxes[:, 2] = torch.ceil(bboxes[:, 2]*W)new_boxxes[:, 1] = torch.floor(bboxes[:, 1]*H)new_boxxes[:, 3] = torch.ceil(bboxes[:, 3]*H)#to intnew_boxxes = new_boxxes.int()for i in range(N):new_boxxes_i = new_boxxes[torch.where(gt_bboxes[:,0]==i)]wmin.append(new_boxxes_i[:, 0])wmax.append(new_boxxes_i[:, 2])hmin.append(new_boxxes_i[:, 1])hmax.append(new_boxxes_i[:, 3])area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))for j in range(len(new_boxxes_i)):Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])Mask_bg = torch.where(Mask_fg > 0, 0., 1.)Mask_bg_sum = torch.sum(Mask_bg, dim=(1,2))Mask_bg[Mask_bg_sum>0] /= Mask_bg_sum[Mask_bg_sum>0].unsqueeze(1).unsqueeze(2)fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, C_attention_s, C_attention_t, S_attention_s, S_attention_t)mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)rela_loss = self.get_rela_loss(preds_S, preds_T)loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_lossreturn loss, loss.detach()def get_attention(self, preds, temp):""" preds: Bs*C*W*H """N, C, H, W= preds.shapevalue = torch.abs(preds)# Bs*W*Hfea_map = value.mean(axis=1, keepdim=True)S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)# Bs*Cchannel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)C_attention = C * F.softmax(channel_map/temp, dim=1)return S_attention, C_attentiondef get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):loss_mse = nn.MSELoss(reduction='sum')Mask_fg = Mask_fg.unsqueeze(dim=1)Mask_bg = Mask_bg.unsqueeze(dim=1)C_t = C_t.unsqueeze(dim=-1)C_t = C_t.unsqueeze(dim=-1)S_t = S_t.unsqueeze(dim=1)fea_t= torch.mul(preds_T, torch.sqrt(S_t))fea_t = torch.mul(fea_t, torch.sqrt(C_t))fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))fea_s = torch.mul(preds_S, torch.sqrt(S_t))fea_s = torch.mul(fea_s, torch.sqrt(C_t))fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)return fg_loss, bg_lossdef get_mask_loss(self, C_s, C_t, S_s, S_t):mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)return mask_lossdef spatial_pool(self, x, in_type):batch, channel, width, height = x.size()input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]if in_type == 0:context_mask = self.conv_mask_s(x)else:context_mask = self.conv_mask_t(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = F.softmax(context_mask, dim=2)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)return contextdef get_rela_loss(self, preds_S, preds_T):loss_mse = nn.MSELoss(reduction='sum')context_s = self.spatial_pool(preds_S, 0)context_t = self.spatial_pool(preds_T, 1)out_s = preds_Sout_t = preds_Tchannel_add_s = self.channel_add_conv_s(context_s)out_s = out_s + channel_add_schannel_add_t = self.channel_add_conv_t(context_t)out_t = out_t + channel_add_trela_loss = loss_mse(out_s, out_t)/len(out_s)return rela_lossdef last_zero_init(self, m):if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)def reset_parameters(self):kaiming_init(self.conv_mask_s, mode='fan_in')kaiming_init(self.conv_mask_t, mode='fan_in')self.conv_mask_s.inited = Trueself.conv_mask_t.inited = Trueself.last_zero_init(self.channel_add_conv_s)self.last_zero_init(self.channel_add_conv_t)

实例化FeatureLoss

在train.py中,实例化我们定义的FeatureLoss,由于我们要蒸馏三层,所以需要定一个蒸馏损失的数组:

if opt.teacher_weights:student_kd_layers = hyp["student_kd_layers"]teacher_kd_layers = hyp["teacher_kd_layers"]dump_image = torch.zeros((1, 3, imgsz, imgsz), device=device)targets = torch.Tensor([[0, 0, 0, 0, 0, 0]]).to(device)_, features = model(dump_image, extra_features = student_kd_layers)  # forward_, teacher_features = teacher_model(dump_image,extra_features=teacher_kd_layers)kd_losses = []for i in range(len(features)):feature = features[i]teacher_feature = teacher_features[i]_, student_channels, _ , _ = feature.shape_, teacher_channels, _ , _ = teacher_feature.shapekd_losses.append(FeatureLoss(student_channels,teacher_channels))

其中hyp[‘xxx_kd_layers’]是用于指定我们要蒸馏的层序号。
为了提取出我们需要的层的特征图,我们还需要对模型推理的代码进行修改,这个放在下一篇,这一篇先把主要流程过一遍。

蒸馏训练

与普通loss一样,在训练中,首先计算蒸馏loss, 然后进行反向传播,区别只是计算蒸馏loss时需要使用teacher模型也对数据进行推理。

if opt.teacher_weights:pred, features = model(imgs, extra_features = student_kd_layers)  # forward_, teacher_features = teacher_model(imgs, extra_features = teacher_kd_layers)if "loss_ota" not in hyp or hyp["loss_ota"] == 1 and epoch >= ota_start:loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)else:loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size# kd lossloss_items = torch.cat((loss_items[0].unsqueeze(0), loss_items[1].unsqueeze(0), loss_items[2].unsqueeze(0), torch.zeros(1, device=device), loss_items[3].unsqueeze(0)))loss_items[-1]*=imgs.shape[0]for i in range(len(features)):feature = features[i]teacher_feature = teacher_features[i]kd_loss, kd_loss_item = kd_losses[i](feature, teacher_feature, targets.to(device), [imgsz,imgsz])loss += kd_lossloss_items[3] += kd_loss_itemloss_items[4] += kd_loss_item

在这里,我们将kd_loss累加到了loss上。计算出总的loss,其他就与普通训练一样了。

结语

这篇文章简述了一下yolov7的蒸馏过程,更多细节将在下一篇中讲述。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/139222.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

王兴投资5G小基站

边缘计算社区获悉,近期深圳佳贤通信正式完成数亿元股权融资,本轮融资由美团龙珠领投。本轮融资资金主要用于技术研发、市场拓展等,将进一步巩固和扩大佳贤通信在5G小基站领域的技术及市场领先地位。 01 佳贤通信是什么样的公司? 深…

Databend hash join spill 设计与实现 | Data Infra 第 16 期

本周六,我们将迎来最新一期的 Data Infra 直播活动,本次活动我们邀请到了 Databend 研发工程师-王旭东,与大家分享主题为《 Databend hash join spill 设计与实现 》的相关知识。 通过本次分享,我们能更加了解 Databend 的 hash …

基于 SaaS 搭建的党建小程序源码系统 带完整的搭建教程

随着互联网技术的发展和应用的普及,传统的党建模式已经难以满足现代社会的需求。为了更好地服务党员和群众,提高党组织的凝聚力和战斗力,基于 SaaS搭建的党建小程序源码系统应运而生。小程序的出现可以很好的解决大多数问题,方便了…

Python突破浏览器TLS/JA3 指纹

JA3 是一种创建 SSL/TLS 客户端指纹的方法,一般一个网站的证书是不变的,所以浏览器指纹也是稳定的,能区分不同的客户端。 requests库 Python requests库请求一个带JA3指纹网站的结果: import requestsheaders {authority: tls…

2023年Q3季度国内手机大盘销额下滑2%,TOP品牌销售数据分析

根据Canalys机构发布的最新报告,2023年第三季度,全球智能手机市场出货量仅下跌1%,可以认为目前全球手机市场的下滑势头有所减缓。而国内线上市场的表现也类似。 根据鲸参谋数据显示,今年Q3京东平台手机累计销量约1100万件&#xf…

2023_Spark_实验十四:SparkSQL入门操作

1、将emp.csv、dept.csv文件上传到分布式环境,再用 hdfs dfs -put dept.csv /input/ hdfs dfs -put emp.csv /input/ 将本地文件put到hdfs文件系统的input目录下 2、或者调用本地文件也可以。区别:sc.textFile("file:///D:\\temp\\emp.csv&qu…

药物滥用第二篇介绍

MTD: 美沙酮(Methadone),是一种有机化合物,化学式为C21H27NO,为μ阿片受体激动剂,药效与吗啡类似,具有镇痛作用,并可产生呼吸抑制、缩瞳、镇静等作用。与吗啡比较&#x…

物流监管:智慧仓储数据可视化监控平台

随着市场竞争加剧和市场需求的不断提高,企业亟需更加高效、智能且可靠的仓储物流管理方式,以提升企业的物流效率,减少其输出成本,有效应对市场上的变化和挑战。 图扑自研 HT for Web 产品搭建的 2D 智慧仓储可视化平台&#xff0c…

深入理解Huffman编码:原理、代码示例与应用

目录 ​编辑 介绍 Huffman编码的原理 信息理论背景 频率统计 Huffman树 Huffman编码的代码示例 数据结构 权重选择 Huffman编码生成 完整示例 完整代码 测试截图 Huffman编码的应用 总结 介绍 在这个数字时代,数据的有效压缩和传输变得至关重要。Hu…

一文学会使用WebRTC API

WebRTC(Web Real-Time Communication)是一项开放标准和技术集合,由 W3C 和 IETF 等组织共同推动和维护,旨在通过Web浏览器实现实时通信和媒体流传输。WebRTC于2011年6月1日开源并在Google、Mozilla、Opera支持下被纳入万维网联盟的…

校园智慧党建小程序源码系统 带完整的搭建教程

大家好啊,今天来给大家分享一款校园智慧党建小程序源码系统。一起来看看吧。以下是部分功能代码图: 系统特色功能一览: 智能化管理:采用人工智能、大数据、云计算等技术手段,实现自动化、智能化管理。例如&#xff0c…

Windows:Arduino IDE 开发环境配置【保姆级】

物联网开发学习笔记——目录索引 参考官网:Arduino - Home Arduino是一款简单易学且功能丰富的开源平台,包含硬件部分(各种型号的Arduino开发板)和软件部分(Arduino IDE)以及广大爱好者和专业人员共同搭建和维护的互联…