1. 写在前面
YOLOv9的Loss计算与YOLOv8如出一辙,仅存在略微的差异。多说一句,数据的预处理和导入方式都是一样的。因此如果你已经对YOLOv8了解的比较透彻,那么对于YOLOv9你也只是需要多关注网络结构就可以。
YOLOv9本身也是Anchor-Free的,同样采用了解耦头。因此其损失计算的关键同样在于“对齐”,即通过TAL方法实现Ground Truth与Pred grid中的Cell进行对齐,然后计算Loss。
2. TAL
当前,基于CNN的目标检测算法基本都是在网络的后期阶段从特征图(Feature Map)中利用检测头进行分类和回归。YOLO系列也都遵循这个法则。
不同于早期的YOLOv3、TOLOv5等网络,YOLOX、YOLOv8以及本篇YOLOv9均采用了解耦检测头(Decoupled-Head),即将分类和回归视作两个独立的分支。
对Decoupled-Head类型的算法基本都有这么一种观念,就是不同任务间,用于分类和定位的Anchor(或乘坐Cell,指分类和回归前的特征图中的一个单元)往往不一致,需要经过对齐操作。所谓对齐,简单来看就是计算Loss时,将分类与回归所使用的预测Cell进行统一化。
TAL(Task Align Learning),最早出现在论文Task-aligned One-stage Object Detection中。改论文提供一种思想,即通过构建一种“对齐度量”,来统一分类和回归的Anchor,进而实现最终在推理时,获得一个更高得分的分类框以及更准确的定位框系数。
更通俗地讲,TAL就是给Feature Map中的每一个Cell(或称作Anchor)分配Ground Truth框。在这种前提下,有的Cell能够分配到Ground Truth框,有的Cell分配不到GT框。根据Feature Map与GT的分配情况,构建用于Loss计算的target_labels、target_bboxes和target_scores。
TAL的关键是构建一个对齐度量Align Metric,对齐度量的计算包含两个部分,首先是构建每个GT对应每个Cell的分类得分矩阵(张量)s,然后构建GT框与预测框的两两IOU(CIOU)u。之后通过如下计算,获得对应的对齐矩阵(Tensor),是为Align Metric。
其中s的获得依赖于预测的分类置信度信息,IOU信息依赖于GT与预测边框信息。二者结合,获得对齐度量矩阵。
3. 选取Cell
从早期的YOLO系列我们可知,并不是Feature Map上所有的Cell都参与计算Loss,因此选取哪些Cell参与计算也是关键的一环。
YOLOv9中,选取的是中心点落在GT框内的Cell参与Loss的计算。如下如所示,所有中心点位于GT框(蓝色)区域内的Cell(黄色标注)均参与计算Loss。
4. topK及二次筛选
标记与某个GT相匹配的Cell,获得shape(bs, n_max_boxes, num_total_anchors)的Tensor,标记的是用于预测某一个GT的Cell。如果某一个Cell参与预测某一个GT,那么该Cell的位置上被置1,否则置0。
需要注意的是,一个Cell只能用于预测一个GT。
官方代码工程中的select_topk_candidates函数即是完成这个任务。
之后经过IOU最大筛选后,会获得每一个Cell所匹配的GT情况。
5. 部分源代码解读
源代码中,比较难理解的就是TAL中计算对齐度量的部分,即如何获得表征GT与Pred之间的对齐关系Align Metric。
(1)get_box_metrics
该函数式TAL的关键,通过该函数我们将获得两个用于对齐的关系矩阵(Tensor),分别是align_metric和overlaps。前者表示每一个GT与每一个Cell(对应网络输出的Feature Map)的匹配得分,overlaps则时GT与Cell的IOU(CIOU)。
我们通过如下计算方式获得align_metric。
这里的align_metric是一个对于预测的边界框和真实边界框进行对齐评估的度量。
这里:
bbox_scores代表边界框的得分,通常反映了模型对于其包含目标的置信度。
overlaps代表预测边界框与真实边界框之间的重叠度,常用IoU(Intersection over Union)来衡量。
self.alpha和self.beta是控制bbox_scores和overlaps在对齐度量中权重的超参数。
通过将边界框得分和重叠度的权重组合,align_metric提供了一个综合指标,用于评估预测边界框的质量。具体来说:
当alpha较大时,模型更重视边界框得分,即模型对自身预测的置信度。
当beta较大时,模型更重视与真实边界框的重叠度,即预测的准确性。
(2)select_candidates_in_gts
收集和标记那些中心点位于GT内的Cell。
我们需要知道,YOLOv9和YOLOv8类似,网络推导出来的是预测框的左上和右下相对于Cell中心点的距离(这一点与YOLOv3、YOLOv5不同)。