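I've been busy lately and haven't posted for a while. I've been using yolov7 a lot recently, so here is a write-up. In the previous post, "yolov5 and yolov7 in Practice: Pruning" (CodingInCV's blog on CSDN), we pruned the model to make it faster with only a small loss of accuracy. That post was about speed. This one improves yolov7 from the angle of detection accuracy, and yolov5 works much the same way.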
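There are the usual ways to improve a detector: adding data, changing the model architecture, or changing the loss. Deep learning offers one more: knowledge distillation. Put simply, a stronger model (the teacher, usually with more parameters) guides a weaker model (the student) during training, which improves the student's performance.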
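There are many ways to distill. A simple, brute-force one is to make the student directly fit the teacher's output feature maps. Distillation is not a cure-all, though. The student has fewer parameters than the teacher and may not absorb the teacher's knowledge well, so whether distillation helps on your own task has to be checked by experiment.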
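For contrast with FGD, here is a minimal sketch of that brute-force variant. MimicLoss is a hypothetical module and is not part of FGD or yolov7. It maps the student's channels to the teacher's with a 1x1 conv and takes the MSE against the teacher's feature map. It assumes both maps have the same spatial size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MimicLoss(nn.Module):
    """Hypothetical baseline: plain feature mimicking (MSE), not FGD."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # 1x1 conv only when channel counts differ; otherwise pass through
        self.align = (
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
            if student_channels != teacher_channels
            else nn.Identity()
        )

    def forward(self, feat_s: torch.Tensor, feat_t: torch.Tensor) -> torch.Tensor:
        feat_s = self.align(feat_s)
        # the teacher map is a fixed target: no gradient flows into it
        return F.mse_loss(feat_s, feat_t.detach())
```

Every spatial position counts equally in this loss. FGD's foreground/background masks and attention weights exist to change exactly that.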
The method chosen here is a distillation algorithm designed for object detection, presented at last year's CVPR:
yzd-v/FGD: Focal and Global Knowledge Distillation for Detectors (CVPR 2022) (github.com)
For an explanation of the method, see: FGD-CVPR2022: Focal and Global Distillation for Object Detection - Zhihu (zhihu.com)
This post does not cover the theory yet. The focus is on integrating the method into yolov7 training, in the following steps.
Loading the teacher model
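Distillation first needs a teacher model. The teacher usually has the same kind of architecture as the student, just with more parameters and more layers. For yolov5, for example, you could try distilling yolov5s with yolov5m.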
Add a command-line argument to train.py:

```python
parser.add_argument("--teacher-weights", type=str, default="", help="teacher model weights path")
```
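The code below reads this flag as `opt.teacher_weights`. The short stdlib sketch that follows shows why: argparse turns the dashes in `--teacher-weights` into underscores in the attribute name. The empty-string default is falsy, so every `if opt.teacher_weights:` branch is skipped when the flag is absent. The checkpoint name `yolov7x.pt` is only an illustrative assumption.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--teacher-weights", type=str, default="", help="teacher model weights path")

# hypothetical checkpoint name, for illustration only
opt = parser.parse_args(["--teacher-weights", "yolov7x.pt"])
print(opt.teacher_weights)  # yolov7x.pt

# without the flag the default "" is falsy, so distillation is disabled
print(bool(parser.parse_args([]).teacher_weights))  # False
```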
Load the teacher weights in the train function. The procedure mirrors the existing loading code for the student. Note that in DP or DDP training the teacher model needs the corresponding treatment too:

```python
# teacher model
if opt.teacher_weights:
    teacher_weights = opt.teacher_weights
    # with torch_distributed_zero_first(rank):
    #     teacher_weights = attempt_download(teacher_weights)  # download if not found locally
    ckpt_t = torch.load(teacher_weights, map_location=device)  # load checkpoint
    # Model() expects a cfg (yaml path or dict), not a .pt path: reuse the cfg stored in the checkpoint
    teacher_model = Model(ckpt_t["model"].yaml, ch=3, nc=nc).to(device)  # create
    state_dict = ckpt_t["model"].float().state_dict()  # to FP32
    teacher_model.load_state_dict(state_dict, strict=True)  # load
    del ckpt_t
    # set to eval: the teacher only runs inference
    teacher_model.eval()
    # set IDetect to train mode
    # teacher_model.model[-1].train()
    logger.info(f"Load teacher model from {teacher_weights}")  # report

# DP mode
if cuda and rank == -1 and torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
    if opt.teacher_weights:
        teacher_model = torch.nn.DataParallel(teacher_model)

# SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    logger.info("Using SyncBatchNorm()")
    if opt.teacher_weights:
        teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model).to(device)
        teacher_model.eval()  # keep the converted BN layers in inference mode

# DDP: only the student is wrapped. DDP rejects a module with no trainable
# parameters, so the frozen teacher just runs locally in each process.
```
The teacher model takes no part in gradient computation, so freeze its parameters:

```python
if opt.teacher_weights:
    for param in teacher_model.parameters():
        param.requires_grad = False
```
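A frozen teacher usually needs three things, and the minimal sketch below bundles them:

- `requires_grad = False` on every parameter.
- `eval()`, so BatchNorm uses its stored running statistics instead of per-batch statistics and never updates them.
- `torch.no_grad()` around its forward pass.

`freeze_teacher` is a hypothetical helper name, and the tiny `Sequential` stands in for the real teacher. yolov7's loop calls `model.train()` on the student each epoch. The teacher stays in eval mode as long as nothing calls `train()` on it.

```python
import torch
import torch.nn as nn


def freeze_teacher(teacher: nn.Module) -> nn.Module:
    """Hypothetical helper: turn a network into a fixed feature extractor."""
    for p in teacher.parameters():
        p.requires_grad = False
    teacher.eval()  # BatchNorm uses running stats, Dropout is disabled
    return teacher


# stand-in teacher: conv + BN
teacher = freeze_teacher(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
with torch.no_grad():  # makes explicit that no autograd graph is built for the teacher
    out = teacher(torch.randn(2, 3, 16, 16))
```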
Distillation loss
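The distillation loss measures the similarity between one or more teacher layers and the corresponding student layers. It supervises the student to move closer to the teacher. For yolov7, we can supervise three feature levels.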
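The FeatureLoss class below computes the following for one feature level, read directly off its code. Symbols:

- $T$ is `temp`.
- $F^s$ and $F^t$ are the (channel-aligned) student and teacher maps of shape $N\times C\times H\times W$.
- $M^{fg}$ and $M^{bg}$ are the foreground/background masks built in `forward`.
- $\alpha,\beta,\gamma,\lambda$ are `alpha_fgd`, `beta_fgd`, `gamma_fgd`, `lambda_fgd`.

Attention is computed separately for each map, and the superscripts $s,t$ mark which map it came from. The relation term $L_{rela}$ is an MSE between the two maps after each is enriched with a GCNet-style global-context vector. It is left in words here.

$$
G^{S}_{n,h,w}=\frac{1}{C}\sum_{c}\lvert F_{n,c,h,w}\rvert,\qquad
A^{S}_{n,h,w}=HW\cdot\frac{\exp\big(G^{S}_{n,h,w}/T\big)}{\sum_{h',w'}\exp\big(G^{S}_{n,h',w'}/T\big)}
$$

$$
G^{C}_{n,c}=\frac{1}{HW}\sum_{h,w}\lvert F_{n,c,h,w}\rvert,\qquad
A^{C}_{n,c}=C\cdot\frac{\exp\big(G^{C}_{n,c}/T\big)}{\sum_{c'}\exp\big(G^{C}_{n,c'}/T\big)}
$$

$$
L_{fg}=\frac{1}{N}\sum_{n,c,h,w}M^{fg}_{n,h,w}\,A^{S,t}_{n,h,w}\,A^{C,t}_{n,c}\,\big(F^{s}_{n,c,h,w}-F^{t}_{n,c,h,w}\big)^{2}
$$

$L_{bg}$ is the same expression with $M^{bg}$ in place of $M^{fg}$. Note that both maps are weighted with the teacher's attention.

$$
L_{mask}=\frac{1}{N}\sum_{n,c}\big\lvert A^{C,s}_{n,c}-A^{C,t}_{n,c}\big\rvert+\frac{1}{N}\sum_{n,h,w}\big\lvert A^{S,s}_{n,h,w}-A^{S,t}_{n,h,w}\big\rvert
$$

$$
L=\alpha L_{fg}+\beta L_{bg}+\gamma L_{mask}+\lambda L_{rela}
$$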
Following the FGD open-source code, we add a FeatureLoss class to loss.py and keep its default hyperparameters for now:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.general import xywh2xyxy  # yolov7 helper: (cx, cy, w, h) -> (x1, y1, x2, y2)

# The FGD reference code uses mmcv's constant_init / kaiming_init; here the same
# initialization is written with torch.nn.init so yolov7 needs no mmcv dependency.


class FeatureLoss(nn.Module):
    """PyTorch version of `Feature Distillation for General Detectors`

    Args:
        student_channels (int): Number of channels in the student's feature map.
        teacher_channels (int): Number of channels in the teacher's feature map.
        temp (float, optional): Temperature coefficient. Defaults to 0.5.
        alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
        lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
    """

    def __init__(
        self,
        student_channels,
        teacher_channels,
        temp=0.5,
        alpha_fgd=0.001,
        beta_fgd=0.0005,
        gamma_fgd=0.001,
        lambda_fgd=0.000005,
    ):
        super(FeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        # 1x1 conv mapping student channels to teacher channels when they differ
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1),
        )
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1),
        )

        self.reset_parameters()

    def forward(self, preds_S, preds_T, gt_bboxes, img_metas):
        """Forward function.

        Args:
            preds_S (Tensor): Bs*C*H*W, student's feature map
            preds_T (Tensor): Bs*C*H*W, teacher's feature map
            gt_bboxes (Tensor): nt*6 yolov7 targets, one row per box:
                (image_index, class, x_center, y_center, width, height),
                coordinates normalized to [0, 1]
            img_metas (list): [img_h, img_w] of the network input (unused,
                since the targets are already normalized)
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:], "the output dim of teacher and student differ"
        device = gt_bboxes.device
        self.to(device)
        if self.align is not None:
            preds_S = self.align(preds_S)

        N, C, H, W = preds_S.shape

        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)

        Mask_fg = torch.zeros_like(S_attention_t)
        wmin, wmax, hmin, hmax = [], [], [], []
        # normalized xywh -> xyxy, then snap outward to feature-map cells
        bboxes = xywh2xyxy(gt_bboxes[:, 2:6])
        new_boxes = torch.ones_like(bboxes)
        new_boxes[:, 0] = torch.floor(bboxes[:, 0] * W)
        new_boxes[:, 2] = torch.ceil(bboxes[:, 2] * W)
        new_boxes[:, 1] = torch.floor(bboxes[:, 1] * H)
        new_boxes[:, 3] = torch.ceil(bboxes[:, 3] * H)
        # guard against tiny negative coordinates, which would become negative slice starts
        new_boxes[:, 0:2] = new_boxes[:, 0:2].clamp(min=0)
        new_boxes = new_boxes.int()

        for i in range(N):
            new_boxes_i = new_boxes[gt_bboxes[:, 0] == i]  # boxes of image i
            wmin.append(new_boxes_i[:, 0])
            wmax.append(new_boxes_i[:, 2])
            hmin.append(new_boxes_i[:, 1])
            hmax.append(new_boxes_i[:, 3])

            # each box weighs 1 / (its area in cells): small objects count more
            area = 1.0 / (hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)) / (wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1))

            for j in range(len(new_boxes_i)):
                # overlapping boxes keep the larger weight
                Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1] = torch.maximum(
                    Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1], area[0][j]
                )

        # background: 1 outside all boxes, normalized so each image's background sums to 1
        Mask_bg = (Mask_fg <= 0).to(Mask_fg.dtype)
        Mask_bg_sum = torch.sum(Mask_bg, dim=(1, 2))
        Mask_bg[Mask_bg_sum > 0] /= Mask_bg_sum[Mask_bg_sum > 0].unsqueeze(1).unsqueeze(2)

        fg_loss, bg_loss = self.get_fea_loss(
            preds_S, preds_T, Mask_fg, Mask_bg, C_attention_s, C_attention_t, S_attention_s, S_attention_t
        )
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)

        loss = (
            self.alpha_fgd * fg_loss
            + self.beta_fgd * bg_loss
            + self.gamma_fgd * mask_loss
            + self.lambda_fgd * rela_loss
        )

        return loss, loss.detach()

    def get_attention(self, preds, temp):
        """preds: Bs*C*H*W"""
        N, C, H, W = preds.shape

        value = torch.abs(preds)
        # spatial attention, Bs*H*W
        fea_map = value.mean(dim=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map / temp).view(N, -1), dim=1)).view(N, H, W)

        # channel attention, Bs*C
        channel_map = value.mean(dim=2, keepdim=False).mean(dim=2, keepdim=False)
        C_attention = C * F.softmax(channel_map / temp, dim=1)

        return S_attention, C_attention

    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction="sum")

        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)

        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)

        S_t = S_t.unsqueeze(dim=1)

        # both feature maps are weighted with the teacher's attention
        fea_t = torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

        fea_s = torch.mul(preds_S, torch.sqrt(S_t))
        fea_s = torch.mul(fea_s, torch.sqrt(C_t))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

        fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)

        return fg_loss, bg_loss

    def get_mask_loss(self, C_s, C_t, S_s, S_t):
        mask_loss = torch.sum(torch.abs(C_s - C_t)) / len(C_s) + torch.sum(torch.abs(S_s - S_t)) / len(S_s)
        return mask_loss

    def spatial_pool(self, x, in_type):
        batch, channel, height, width = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction="sum")

        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = loss_mse(out_s, out_t) / len(out_s)

        return rela_loss

    def last_zero_init(self, m):
        # same as mmcv's constant_init(m, val=0) applied to the last layer
        if isinstance(m, nn.Sequential):
            m = m[-1]
        nn.init.constant_(m.weight, 0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def reset_parameters(self):
        # same as mmcv's kaiming_init(m, mode="fan_in") (normal, relu, zero bias)
        for conv in (self.conv_mask_s, self.conv_mask_t):
            nn.init.kaiming_normal_(conv.weight, mode="fan_in", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)
            conv.inited = True

        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)
```
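The mask construction in `forward` is the least obvious part. Below is a minimal pure-Python restatement of the Mask_fg / Mask_bg logic for a single image, under the same assumptions as the code above:

- Boxes are normalized `(cx, cy, w, h)` in YOLO label format.
- A box covers the cells from `floor(x1 * W)` to `ceil(x2 * W)` inclusive.
- Every covered cell gets `1 / (box area in cells)`.
- Background cells share a total weight of 1.

`fgd_masks` is a hypothetical helper for illustration, not part of FGD or yolov7.

```python
import math


def fgd_masks(boxes_xywh, H, W):
    """Hypothetical pure-Python version of the Mask_fg / Mask_bg logic
    in FeatureLoss.forward, for one image. Returns (fg, bg) as H x W lists."""
    fg = [[0.0] * W for _ in range(H)]
    for cx, cy, bw, bh in boxes_xywh:
        # xywh -> xyxy, snapped outward to feature-map cells
        x1 = max(math.floor((cx - bw / 2) * W), 0)
        y1 = max(math.floor((cy - bh / 2) * H), 0)
        x2 = math.ceil((cx + bw / 2) * W)
        y2 = math.ceil((cy + bh / 2) * H)
        # inclusive bounds, as in the slicing hmin:hmax + 1
        weight = 1.0 / ((y2 + 1 - y1) * (x2 + 1 - x1))
        for y in range(y1, min(y2 + 1, H)):
            for x in range(x1, min(x2 + 1, W)):
                fg[y][x] = max(fg[y][x], weight)  # overlaps keep the larger weight
    # background cells share a total weight of 1 (if there are any)
    n_bg = sum(v == 0.0 for row in fg for v in row)
    bg_w = 1.0 / n_bg if n_bg else 0.0
    bg = [[bg_w if v == 0.0 else 0.0 for v in row] for row in fg]
    return fg, bg


fg, bg = fgd_masks([(0.5, 0.5, 0.5, 0.5)], 4, 4)
print(fg[1][1], fg[0][0])  # 0.111... 0.0
```

Because of the inclusive upper bound, a box spanning [0.25, 0.75] of a 4-cell axis marks cells 1 through 3, that is, 9 cells on a 4x4 map. Each of them gets 1/9, so every box contributes a total foreground weight of about 1 regardless of its size.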
Instantiating FeatureLoss
In train.py, instantiate the FeatureLoss defined above. Since we distill three layers, we need a list of distillation losses:
```python
if opt.teacher_weights:
    student_kd_layers = hyp["student_kd_layers"]
    teacher_kd_layers = hyp["teacher_kd_layers"]
    # one dummy forward pass to read the channel count of each distilled layer
    # (the student is in train mode here, so this pass nudges its BN running stats once)
    dummy_image = torch.zeros((1, 3, imgsz, imgsz), device=device)
    with torch.no_grad():
        _, features = model(dummy_image, extra_features=student_kd_layers)  # forward
        _, teacher_features = teacher_model(dummy_image, extra_features=teacher_kd_layers)
    kd_losses = []
    for feature, teacher_feature in zip(features, teacher_features):
        _, student_channels, _, _ = feature.shape
        _, teacher_channels, _, _ = teacher_feature.shape
        kd_losses.append(FeatureLoss(student_channels, teacher_channels))
```
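FeatureLoss has learnable parts: the `align` 1x1 conv when channel counts differ, `conv_mask_s/t`, and `channel_add_conv_s/t`. In the code shown in this post, `kd_losses` is a plain Python list and its parameters are never handed to an optimizer. As far as this post goes, those layers would therefore stay at their initialization.

Below is a minimal sketch of one way to change that, assuming a standard `torch.optim` optimizer. `add_kd_param_group` is a hypothetical helper, and the `nn.Linear`/`nn.Conv2d` modules stand in for the real student and FeatureLoss modules. In stock yolov7 train.py, the LambdaLR scheduler and the warmup loop read each group's `initial_lr`. So this would need to run after moving the modules to `device` and before the scheduler is created.

```python
import torch
import torch.nn as nn


def add_kd_param_group(optimizer, kd_losses):
    """Hypothetical helper: append the distillation modules' learnable
    parameters as one extra param group. Missing settings (lr, momentum, ...)
    are filled in from optimizer.defaults by add_param_group."""
    params = [p for m in kd_losses for p in m.parameters() if p.requires_grad]
    optimizer.add_param_group({"params": params})
    return len(params)


# stand-ins for the student and two FeatureLoss modules
student = nn.Linear(4, 2)
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
n_added = add_kd_param_group(optimizer, [nn.Conv2d(4, 8, 1), nn.Conv2d(8, 8, 1)])
```

Holding the losses in an `nn.ModuleList` instead of a list would also let a single `.to(device)` and `.parameters()` cover all of them.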
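Here hyp["student_kd_layers"] and hyp["teacher_kd_layers"] give the indices of the layers to distill in the student and the teacher respectively.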
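The post does not say which indices it uses. A natural choice that matches "three feature levels" is the layers feeding the final Detect/IDetect head. The student and teacher usually number these layers differently, which is why two keys are needed.

The sketch below reads those indices from a yolov5/yolov7-style model dict. It assumes the usual yaml layout: `backbone` and `head` lists of `[from, number, module, args]` rows, numbered consecutively, with negative `from` values relative to the current layer. `detect_input_layers` is a hypothetical helper, and `toy_cfg` is made up, not a real yolov7 yaml. In training you might pass it the dict that yolov7's `Model` keeps in its `yaml` attribute (`model.module.yaml` under DP).

```python
def detect_input_layers(cfg: dict) -> list:
    """Hypothetical helper: absolute indices of the layers feeding the
    last `head` entry (Detect/IDetect) of a YOLO-style model dict."""
    detect_idx = len(cfg["backbone"]) + len(cfg["head"]) - 1
    from_field = cfg["head"][-1][0]
    if isinstance(from_field, int):
        from_field = [from_field]
    # negative "from" values are relative to the layer that uses them
    return [f if f >= 0 else detect_idx + f for f in from_field]


# toy config, not a real yolov7 yaml
toy_cfg = {
    "backbone": [[-1, 1, "Conv", [32, 3, 1]], [-1, 1, "Conv", [64, 3, 2]], [-1, 1, "Conv", [128, 3, 2]]],
    "head": [[-1, 1, "Conv", [128, 1, 1]], [-1, 1, "Conv", [256, 3, 2]],
             [[1, 3, 4], 1, "IDetect", ["nc", "anchors"]]],
}
print(detect_input_layers(toy_cfg))  # [1, 3, 4]
```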
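To extract the feature maps of these layers, the model's forward code also has to be modified to support the `extra_features` argument used above. That change goes in the next post. This one walks through the main flow first.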
Distillation training
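As with the ordinary loss, training first computes the distillation loss and then backpropagates. The only difference is that the teacher model must also run inference on each batch to compute the distillation loss.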
```python
if opt.teacher_weights:
    pred, features = model(imgs, extra_features=student_kd_layers)  # forward
    with torch.no_grad():  # the teacher is frozen: no autograd graph needed
        _, teacher_features = teacher_model(imgs, extra_features=teacher_kd_layers)
    # parenthesized so that ota_start applies whether or not "loss_ota" is in hyp
    if ("loss_ota" not in hyp or hyp["loss_ota"] == 1) and epoch >= ota_start:
        loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)
    else:
        loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
    # kd loss
    # loss_items: (box, obj, cls, total) -> (box, obj, cls, kd, total)
    loss_items = torch.cat((loss_items[:3], torch.zeros(1, device=device), loss_items[3:]))
    loss_items[-1] *= imgs.shape[0]  # make "total" match `loss`, which is scaled by batch size
    # don't reuse `i` here: it is the batch index of the enclosing training loop
    for kd_loss_fn, feature, teacher_feature in zip(kd_losses, features, teacher_features):
        kd_loss, kd_loss_item = kd_loss_fn(feature, teacher_feature, targets.to(device), [imgsz, imgsz])
        loss += kd_loss
        loss_items[3] += kd_loss_item
        loss_items[4] += kd_loss_item
```
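With this change, `loss_items` has five entries (box, obj, cls, kd, total) instead of four. In stock yolov7 train.py, the running mean `mloss` is created with four entries and the progress-bar header has four loss columns. Assuming your train.py matches upstream, both need one more slot.

The running-mean update `mloss = (mloss * i + loss_items) / (i + 1)` also relies on the batch index `i`, which is why the kd loop above must not reuse `i` as its loop variable. Below is a minimal pure-Python sketch of both pieces. `LOSS_NAMES` and `update_mean_losses` are hypothetical names.

```python
LOSS_NAMES = ("box", "obj", "cls", "kd", "total")


def update_mean_losses(mloss, loss_items, i):
    """Running mean over batches, same formula as yolov7's
    mloss = (mloss * i + loss_items) / (i + 1), element-wise on lists."""
    return [(m * i + x) / (i + 1) for m, x in zip(mloss, loss_items)]


# progress-bar header with one extra column for the kd loss
header = ("%10s" * 9) % ("Epoch", "gpu_mem", *LOSS_NAMES, "labels", "img_size")

mloss = [0.0] * len(LOSS_NAMES)
mloss = update_mean_losses(mloss, [1.0, 2.0, 3.0, 4.0, 10.0], 0)  # batch i = 0
mloss = update_mean_losses(mloss, [3.0, 2.0, 1.0, 0.0, 6.0], 1)   # batch i = 1
print(mloss)  # [2.0, 2.0, 2.0, 2.0, 8.0]
```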
Here kd_loss is added to loss. Once the total loss is computed, everything else is the same as ordinary training.
Conclusion
This post outlined the distillation process for yolov7. More details will follow in the next post.