Datawhale AI 夏令营 第五期 CV方向 02进阶

上次的baseline方案,训练的模型,获得分数并不高,DataWhale提供了两个上分的思路

  1. 增加训练数据集
  2. 切换不同模型预训练权重

增加训练集的大小通常可以提高模型的泛化能力,因为更多的数据可以帮助模型学习到更多的特征和模式。但是,越大的数据集,就意味着需要更多的计算资源和时间来训练模型,以及可能出现的过拟合问题。

增加训练数据集

增大数据集的一些方法:

  1. 数据增强: 通过对现有数据进行变换(如旋转、缩放、裁剪、颜色调整等)来增加数据集的多样性。

  2. 合成数据: 使用数据合成技术生成新的训练样本,尤其是在数据稀缺的情况下。

  3. 数据挖掘: 从互联网或公共数据集中收集更多相关数据。

  4. 众包: 利用众包平台收集和标注数据。

  5. 迁移学习: 使用预训练模型作为起点,然后在较小的数据集上进行微调。

  6. 分层抽样: 确保数据集中的每个类别都有足够数量的样本。

  7. 交叉验证: 使用交叉验证来更有效地利用有限的数据,同时评估模型的稳定性。

  8. 正则化技术: 如L1或L2正则化,以减少过拟合的风险。

  9. 早停法: 在验证集上的性能不再提升时停止训练,以避免过拟合。

  10. 调整模型复杂度: 根据数据集的大小调整模型的复杂度,以找到最佳的模型容量。

这里,我们直接从数据集中划分更多的数据作为训练数据,同时,验证集也增大

训练集增大到30

for anno_path, video_path in zip(train_annos[:30], train_videos[:30]):print(video_path)anno_df = pd.read_json(anno_path)cap = cv2.VideoCapture(video_path)frame_idx = 0 while True:ret, frame = cap.read()if not ret:breakimg_height, img_width = frame.shape[:2]frame_anno = anno_df[anno_df['frame_id'] == frame_idx]cv2.imwrite('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)if len(frame_anno) != 0:with open('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):category_idx = category_labels.index(category)x_min, y_min, x_max, y_max = bboxx_center = (x_min + x_max) / 2 / img_widthy_center = (y_min + y_max) / 2 / img_heightwidth = (x_max - x_min) / img_widthheight = (y_max - y_min) / img_heightif x_center > 1:print(bbox)up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')frame_idx += 1

验证集

for anno_path, video_path in zip(train_annos[-10:], train_videos[-10:]):print(video_path)anno_df = pd.read_json(anno_path)cap = cv2.VideoCapture(video_path)frame_idx = 0 while True:ret, frame = cap.read()if not ret:breakimg_height, img_width = frame.shape[:2]frame_anno = anno_df[anno_df['frame_id'] == frame_idx]cv2.imwrite('./yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)if len(frame_anno) != 0:with open('./yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):category_idx = category_labels.index(category)x_min, y_min, x_max, y_max = bboxx_center = (x_min + x_max) / 2 / img_widthy_center = (y_min + y_max) / 2 / img_heightwidth = (x_max - x_min) / img_widthheight = (y_max - y_min) / img_heightup.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')frame_idx += 1

切换不同模型预训练权重

先了解一下YOLO系列中常见的不同版本(s, m, l, x)的区别:

  1. YOLO-S (Small): 这是YOLO系列中的小型版本,通常具有较少的参数和较低的计算需求。它适用于资源受限的环境,如移动设备或嵌入式系统,但可能在检测精度上有所牺牲。

  2. YOLO-M (Medium): 中型版本提供了一个平衡点,它比小型版本有更多的参数和更高的计算需求,同时保持了较好的检测精度和速度。

  3. YOLO-L (Large): 大型版本拥有最多的参数和最高的计算需求。它提供了更高的检测精度,但速度可能会慢于小型和中型版本。

  4. YOLO-X (Extra Large): 这是YOLO系列中的超大型版本,它具有最多的参数和最高的计算需求。YOLO-X通常用于需要最高精度的场景,尽管它的速度可能不如其他版本快。

这里选择了YOLOv8s的预训练模型

同时增加训练回合

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"import warnings
warnings.filterwarnings('ignore')from ultralytics import YOLO
# model = YOLO("yolov8n.pt")
model = YOLO("yolov8s.pt")
results = model.train(data="yolo-dataset/yolo.yaml", epochs=30, imgsz=1080, batch=16)

这是baseline的训练日志

这是优化以后的训练日志

可以看到:
泛化能力(dfl_loss)和准确性(cls_loss)都有提高

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/789368.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

luoguP5369 [PKUSC2018] 最大前缀和

题目n<=20 题解 想了半天3位状态的折半,然后发现空间开不下(时间也不太行) 所以放弃思考,直接枚举答案答案是a中的一个集合,设为S;记集合S的和为sum[S] 考虑当S确定时,有多少种方案能使答案恰好为sum[S]。为了处理多种sum相同的情况,记S为从前往后考虑,第一次出现最…

Java Script网页设计案例

本文提供了一个简单的JavaScript网页设计案例,该案例将实现一个动态的待办事项列表(Todo List)。用户可以在页面上添加新的待办事项,标记它们为已完成,以及删除它们。这个案例将使用HTML来构建页面结构,CSS来美化页面,以及JavaScript来添加动态功能。1. JavaScript网页设…

ggml 简介

ggml 是一个用 C 和 C++ 编写、专注于 Transformer 架构模型推理的机器学习库。该项目完全开源,处于活跃的开发阶段,开发社区也在不断壮大。ggml 和 PyTorch、TensorFlow 等机器学习库比较相似,但由于目前处于开发的早期阶段,一些底层设计仍在不断改进中。 相比于 llama.cp…

层序遍历(广度优先搜索)-102

题目描述 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。解题思路 这里我们层次遍历我们需要使用到队列这个数据结构,我们依次从根节点开始遍历,我们需要使用一个变量来记录此时我们队列中元素的数量,因为这样我们才知道这一层…

CF1980F1 F2 Field Division

前言 纪念一下独立做出来的 \(2400\) 的题 Easy version 思路 先说 \(Easy\) 版本的 我们走路的方式只有可能是这种样子:(出处:luogu user FiraCode) 不想手绘图了 即对列排序后,所形成的一个行编号上升的序列 所以 \(Easy\) 就很简单了,对于每一列的最大值,如果大于当前前…

一篇文章讲清楚Java中的反射

介绍 每个类都有一个 Class 对象,包含了与类有关的信息。当编译一个新类时,会产生一个同名的 .class 文件,该文件内容保存着 Class 对象。 类加载相当于 Class 对象的加载。类在第一次使用时才动态加载到 JVM 中,可以使用 Class.forName("com.mysql.jdbc.Driver"…

django timezone.now 小了8小时

django.util.timezone.now()原因:setting.py中设置了时区:LANGUAGE_CODE = en-usTIME_ZONE = UTCUSE_I18N = TrueUSE_TZ = True # 若数据库中存储的是UTC时间,但在模板显示的时候,会转成TIME_ZONE所示的本地时间进行显示****将TIME_ZONE时区改为:TIME_ZONE = Asia/Shangha…

枚举实现原理

枚举的定义 在JDK1.5之前,我们要是想定义一些有关常量的内容,例如定义几个常量,表示从周一到周末,一般都是在一个类,或者一个接口中,写类似于如下代码: public class WeekDayConstant {public static final int MONDAY = 0;public static final int TUESDAY = 1;public …

航图中的扇区数据生成

今天简单聊一下机场扇区,最后再介绍一下《风标设计2025》航图模块中的扇区绘制功能。机场扇区是以机场基准点(ARP)或归航台为圆心,半径46km(25nm),外加9km(5nm)缓冲区构成。PBN程序扇区通常以ARP为圆心,传统程序扇区以导航台为圆心。对于中小机场,为了统一和简化扇区划设…

【Azure Policy】添加策略用于审计Azure 网络安全组(NSG)规则 -- 只能特定的IP地址允许3389/22端口访问

问题描述 对Azure上的虚拟机资源,需要进行安全管理。只有指定的IP地址才能够通过RDP/SSH远程到虚拟机上, 有如下几点考虑: 1) 使用Azure Policy服务,扫描订阅中全部的网络安全组(NSG: Network Security Group) 资源 2) 判断入站规则,判断是否是3389, 22端口 3) 判断源地…