基于Pytorch的猫狗图片分类【深度学习CNN】

猫狗分类来源于Kaggle上的一个入门竞赛——Dogs vs Cats。为了加深对CNN的理解,基于Pytorch复现了LeNet,AlexNet,ResNet等经典CNN模型,源代码放在GitHub上,地址传送点击此处。项目大纲如下:
在这里插入图片描述


文章目录

  • 一、问题描述
  • 二、数据集处理
    • 1 损坏图片清洗
    • 2 抽取图片形成数据集
  • 三、图片预处理
    • (1)init 方法
    • (2)getitem方法
    • (3)len方法
    • (4)测试
  • 四、模型
    • 1 LeNet
    • 2 AlexNet模型
  • 五、训练
    • 1 开始训练
    • 2 tensorboard可视化
  • 六、不同模型训练结果分析
    • 1 LeNet模型
      • (1) 数据集数量=1000,无数据增强
      • (2) 数据集数量=4000,无数据增强
      • (3)数据集数量=4000,数据增强
      • (4)数据集=4000,数据增强
      • (5)使用dropout函数抑制过拟合
    • 2 AlexNet模型
    • 3 squeezeNet模型
    • 4 resNet模型
    • 总结
  • 七、预测


一、问题描述

基于训练集数据,训练一个模型,利用训练好的模型预测未知图片中的动物是狗或者猫的概率。

训练集有25,000张图片,测试集12,500 张图片。

数据集下载地址:https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

截屏2024-02-19 15.56.01

二、数据集处理

1 损坏图片清洗

01_clean.py中,用多种方式来清洗损坏图片:

  1. 判断开头是否有JFIF
  2. 用imghdr库中的imghdr.what函数判断文件类型
  3. 用Image.open(filename).verify()验证图片是否损坏

结果如下:

截屏2022-04-20 下午1.54.15

2 抽取图片形成数据集

由于一万多张图片比较多,并且需要将Cat类和Dog类的图片合在一起并重新命名,方便获得每张图片的labels,所以可以从原图片文件夹复制任意给定数量图片到train的文件夹,并且重命名如下:

截屏2022-04-22 下午3.58.33

程序为:02_data_processing.py.

三、图片预处理

图片预处理部分需要完成:

  1. 对图片的裁剪:将大小不一的图片裁剪成神经网络所需的,我选择的是裁剪为**(224x224)**
  2. 转化为张量
  3. 归一化:三个方向归一化
  4. 图片数据增强
  5. 形成加载器:返回图片数据和对应的标签,利用Pytorch的Dataset包

dataset.py中定义Mydata的类,继承pytorch的Dataset,定义如下三个方法:

(1)init 方法

读取图片路径,并拆分为数据集和验证集(以下代码仅体现结构,具体见源码):

class Mydata(data.Dataset):"""定义自己的数据集"""def __init__(self, root, Transforms=None, train=True):"""进行数据集的划分"""if train:self.imgs = imgs[:int(0.8*imgs_num)]  #80%训练集else:self.imgs = imgs[int(0.8*imgs_num):]  #20%验证集"""定义图片处理方式"""if Transforms is None:normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])self.transforms = transforms.Compose([ transforms.CenterCrop(224), transforms.Resize([224,224]),transforms.ToTensor(), normalize])

(2)getitem方法

对图片处理,返回数据和标签:

 def __getitem__(self, index):return data, label

(3)len方法

返回数据集大小:

    def __len__(self):"""返回数据集中所有图片的个数"""  return len(self.imgs)

(4)测试

实例化数据加载器后,通过调用getitem方法,可以得到经过处理后的 3 × 244 × 244 3\times244\times244 3×244×244的图片数据

if __name__ == "__main__":root = "./data/train"train = Mydata(root, train=True)  #实例化加载器img,label=train.__getitem__(5)    #获取index为5的图片print(img.dtype)print(img.size(),label)   print(len(train))    #数据集大小
#输出
torch.float32
torch.Size([3, 224, 224]) 0
3200

裁剪处理后图片如下所示,大小为224X224

截屏2022-04-22 下午5.28.56

四、模型

模型都放在 models.py中,主要用了一些经典的CNN模型:

  1. LeNet
  2. ResNet
  3. ResNet
  4. SqueezeNet

下面给出重点关注的LeNet模型和AlexNet模型:

1 LeNet

LeNet模型是一个早期用来识别手写数字图像的卷积神经网络,这个名字来源于LeNet论文的第一作者Yann LeCun。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时最先进的结果,LeNet模型结构图示如下所示:

截屏2022-04-29 下午7.54.44

由上图知,LeNet分为卷积层块全连接层块两个部分,在本项目中我对LeNet模型做了相应的调整

  1. 采用三个卷积层
  2. 三个全连接层
  3. ReLu作为激活函数
  4. 在卷积后正则化
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.relu = nn.ReLU()self.sigmoid = nn.Sigmoid()#三个卷积层self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=2,),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=2,),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,stride=2,),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)#三个全连接层self.fc1 = nn.Linear(3 * 3 * 64, 64)self.fc2 = nn.Linear(64, 10)self.out = nn.Linear(10, 2)   #分类类别为2,def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.shape[0], -1)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.out(x)return x

调用torchsummary库,可以观察模型的结构、参数:

截屏2022-04-30 上午12.35.15

2 AlexNet模型

2012年,AlexNet横空出世,这个模型的名字来源于论文第一作者的姓名Alex Krizhevsky。AlexNet使用了8层卷积神经网络,由5个卷积层和3个池化Pooling 层 ,其中还有3个全连接层构成。AlexNet 跟 LeNet 结构类似,但使⽤了更多的卷积层和更⼤的参数空间来拟合⼤规模数据集 ImageNet,它是浅层神经⽹络和深度神经⽹络的分界线。

特点:

  1. 在每个卷积后面添加了Relu激活函数,解决了Sigmoid的梯度消失问题,使收敛更快。
  2. 使用随机丢弃技术(dropout)选择性地忽略训练中的单个神经元,避免模型的过拟合(也使用数据增强防止过拟合)
  3. 添加了归一化LRN(Local Response Normalization,局部响应归一化)层,使准确率更高。
  4. 重叠最大池化(overlapping max pooling),即池化范围 z 与步长 s 存在关系 z>s 避免平均池化(average pooling)的平均效应

五、训练

训练在 main.py中,主要是对获取数据、训练、评估、模型的保存等功能的整合,能够实现以下功能:

  1. 指定训练模型、epoches等基本参数
  2. 是否选用预训练模型
  3. 接着从上次的中断的地方继续训练
  4. 保存最好的模型和最后一次训练的模型
  5. 对模型的评估:Loss和Accuracy
  6. 利用TensorBoard可视化

1 开始训练

main.py程序中,设置参数和模型(models.py中可以查看有哪些模型):

截屏2022-04-29 下午11.22.34

在vscode中点击运行或在命令行中输入:

python3 main.py

即可开始训练,开始训练后效果如下:

截屏2022-04-30 上午8.24.14

若程序中断,设置resume参数为True,可以接着上次的模型继续训练,可以非常方便的任意训练多少次

2 tensorboard可视化

在vscode中打开tensorboard,或者在命令行中进入当前项目文件夹下输入

tensorboard --logdir runs

即可打开训练中的可视化界面,可以很方便的观察模型的效果:

截屏2022-04-30 上午8.28.37

如上图所示,可以非常方便的观察任意一个模型训练过程的效果!

六、不同模型训练结果分析

1 LeNet模型

在用LeNet模型训练的过程中,通过调整数据集数量、是否用数据增强等不同的方法,来训练模型,并观察模型的训练效果。

(1) 数据集数量=1000,无数据增强

通过Tensorboard可视化可以观察到:

  1. 验证集准确率(Accuracy)在上升,训练30epoch左右,达到最终**63%**左右的最好效果
  2. 但验证集误差(Loss)也在上升,训练集误差一直下降
  3. 训练集误差接近于0

说明模型在训练集上效果好,验证集上效果不好,泛化能力差,可以推测出模型过拟合了。而这个原因也是比较好推测的,数据集比较少。

截屏2022-04-29 下午8.23.09

(2) 数据集数量=4000,无数据增强

同样过拟合了,但是最后的准确率能达到**68%**左右,说明数据集增加有效果

截屏2022-04-29 下午8.32.01

(3)数据集数量=4000,数据增强

这次数据集数量同上一个一样为4000,但采用了如下的数据增强:

  1. 水平翻转,概率为p=0.5
  2. 上下翻转,概率为p=0.1

我们可以看到这次一开始验证集误差是下降的,说明一开始没有过拟合,但到15个epoch之后验证集误差开始上升了,说明已经开始过拟合了,但最后的准确率在**71%**左右,说明数据增强对扩大数据集有明显的效果。

截屏2022-04-29 下午8.38.00

(4)数据集=4000,数据增强

这次数据集数量为4000,但采用了如下的数据增强:

  1. 水平翻转,概率为p=0.5
  2. 上下翻转,概率为p=0.5
  3. 亮度变化截屏2022-04-29 下午8.48.10

可以看到:

  1. 35个epoch之前,验证集误差呈下降趋势,准确率也一直上升,最高能到75%
  2. 但在35个epoch之后,验证集误差开始上升,准确率也开始下降

说明使用了更强的数据增强之后,模型效果更好了。

截屏2022-04-29 下午8.50.01

(5)使用dropout函数抑制过拟合

本次数据集和数据增强方式同(4),但是在模型的第一个全连接层加入dropout函数。

dropout原理:

训练过程中随机丢弃掉一些参数。在前向传播的时候,让某个神经元的激活值以一定的概率p(伯努利分布)停止工作,这样可以使模型泛化性更强。截屏2022-04-29 下午8.59.39

不使用dropout示意图 使用dropout示意图

这样相当于每次训练的是一个比较"瘦"的模型,更不容易过拟合

加入dropout函数后,训练85个epochs,可以观察到效果十分显著

  1. 验证集的误差总体呈现下降趋势,且最后没有反弹
  2. 训练集误差下降比较慢了!
  3. 准确率一直上升,最后可以达到76%

说明模型最后没有过拟合,并且效果还不错。

截屏2022-04-29 下午9.03.21

2 AlexNet模型

将AlexNet模型参数打印出来:

截屏2022-04-30 上午12.58.58

可以看到AlexNet相比LeNet,参数数目有数量级的上升,而在数据量比较小的情况下,很容易梯度消失,经过反复的调试:

  1. 要在卷积层加入正则化
  2. 优化器选择SGD
  3. 学习率不能过大

才能避免验证集的准确率一直在50%

经过调试,较好的一次结果如下所示,最终准确率能达到78%

截屏2022-04-30 上午1.10.08

3 squeezeNet模型

在后面两个模型中,使用迁移学习的方法。

**迁移学习(Transfer Learning)**是机器学习中的一个名词,是指一种学习对另一种学习> 的影响,或习得的经验对完成其它活动的影响。迁移广泛存在于各种知识、技能与社会规范> 的学习中,将某个领域或任务上学习到的知识或模式应用到不同但相关的领域或问题中。``截屏2022-04-29 下午11.58.32```

使用squeezeNet预训练模型,在迭代16个epoch后,准确率可以达到93%

截屏2022-04-29 下午11.51.43

4 resNet模型

使用resnet50的预训练模型,训练25个epoch后,准确率可以达到98%!

截屏2022-04-30 上午12.12.36

总结

模型测试集预测准确率
LeNet(无数据增强)68%
LeNet(数据增强)75%
LeNet(采用Dropout)76%
Alexnet78%
squeezeNet(迁移学习)93%
resNet98%

七、预测

模型训练好后,可以打开 predict.py对新图片进行预测,给定用来预测的模型和预测的图片文件夹:

 model = LeNet1() # 模型结构modelpath = "./runs/LeNet1_1/LeNet1_best.pth" # 训练好的模型路径checkpoint = torch.load(modelpath)  model.load_state_dict(checkpoint)  # 加载模型参数root = "test_pics"

运行 predict.py 会将预测的图片储存在 output文件夹中,如下图所示:

pre_04_cat

会给出预测的类别和概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/492219.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单片机一个32位地址对应多大的存储空间?

文章目录 文字图片 文字 一个地址是4个字节 一个地址对应一个字节的存储空间(无论8位、16位、32位单片机) 学过C语言的都知道:指针就是地址,因此指针也是4个字节 图片 这两张是工作的笔记、主要看第二张,左边是代码&…

第7.1章:StarRocks性能调优——查询分析

目录 一、查看查询计划 1.1 概述 1.2 查询计划树 1.3 查看查询计划的命令 1.3 查看查询计划 二、查看查询Profile 2.1 启用 Query Profile 2.2 获取 Query Profile 2.3 Query Profile结构与详细指标 2.3.1 Query Profile的结构 2.3.2 Query Profile的合并策略 2.…

单链表详解

个人主页:不爱学英文的码字机器-CSDN博客 收录合集:《数据结构》 在本篇博客中,我们将深入探讨单链表的定义、实现和应用。 本篇博客将用C语言实现的单链表进行讲解,通过一段代码一段讲解来逐个详细讲解,深入了解单链表…

Java编程与数据库技术:疫情居家办公的坚实后盾

✍✍计算机毕业编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java、…

基于自适应波束成形算法的matlab性能仿真,对比SG和RLS两种方法

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于自适应波束成形算法的matlab性能仿真,对比SG和RLS两种方法. 2.测试软件版本以及运行结果展示 MATLAB2022a版本运行 3.核心程序 ........................…

字符函数和字符串函数(C语言进阶)(三)

目录 前言 接上篇: 1.7 strtok 1.8 strerror 1.9 字符分类函数 总结 前言 C语言中对字符和字符串的处理是很频繁的,但是c语言本身是没有字符串类型的,字符串通常放在常量字符串中或着字符数组中。 字符串常量适用于那些对它不做修改的字…

某电力铁塔安全监测预警系统案例分享

项目概述 电力铁塔是承载电力供应的重要设施,它的安全性需要得到可靠的保障。但是铁塔一般安装在户外,分布广泛,且有很多安装在偏远地区,容易受到自然、人力的影响和破环。因此需要使用辅助的方法实时监控铁塔的安全状态&#xff…

使用GPT生成python图表

首先,生成一脚本,读取到所需的excel表格 import xlrddata xlrd.open_workbook(xxxx.xls) # 打开xls文件 table data.sheet_by_index(0) # 通过索引获取表格# 初始化奖项字典 awards_dict {"一等奖": 0,"二等奖": 0,"三等…

HarmonyOS—代码Code Linter检查

Code Linter代码检查 Code-Linter针对ArkTS/TS代码进行最佳实践、编程规范方面的检查,目前还会检查ArkTS语法规则。开发者可根据扫描结果中告警提示手工修复代码缺陷,或者执行一键式自动修复,在代码开发阶段,确保代码质量。 检查…

第四节:Vben Admin登录对接后端getUserInfo接口

系列文章目录 第一节:Vben Admin介绍和初次运行 第二节:Vben Admin 登录逻辑梳理和对接后端准备 第三节:Vben Admin登录对接后端login接口 第四节:Vben Admin登录对接后端getUserInfo接口 文章目录 系列文章目录前言一、回顾Vben…

C语言内存管理-栈内存

栈内存 什么东西存储在栈内存中? 环境变量命令行参数局部变量(包括形参)栈内存有什么特点? 空间有限,尤其在嵌入式环境下。因此不可以用来存储尺寸太大的变量。每当一个函数被调用,栈就会向下增长一段&…

Codeforce Monsters Attack!(B题 前缀和)

题目描述: 思路: 本人第一次的想法是先杀血量低的第二次想法是先搞坐标近的第三次想法看到数据量这么大, 我先加个和看看貌似我先打谁都行,由此综合一下, 我们可以把每一个不同的坐标当作一轮从最小的坐标开始&#x…