第P7周:咖啡豆识别(VGG-16复现)

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客**
>- **🍖 原作者:[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)**

一、前期工作

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



2. 导入数据

import os,PIL,random,pathlibdata_dir = './7-data/'
data_dir = pathlib.Path(data_dir)data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder("./7-data/",transform=train_transforms)
total_data

3. 划分数据集





二、手动搭建VGG-16模型

VGG-16结构说明:

●13个卷积层(Convolutional Layer),分别用blockX_convX表示
●3个全连接层(Fully connected Layer),分别用fcX与predictions表示
●5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

这里,我制作了一个视频来展示VGG-16的传播过程

play

0:00/0:22

倍速

volumeUp

fullscreen

fullscreen

VGG-16网络动画展示

FC7

FE8

FC6

7*7*512

1*1*512

1*1*1000

1*1*512

CONV5

CONV4

14*14*512

CONV3

28*28*512

CONV2

56*56*256

112*112*128

CONVOLUTION+RELU

K同学啊制作

MAX POOLING

CONVI

百度/谷歌/微信搜索:K同学啊

224*224*64

FULLY CONNECTED+RELU

image.png


1. 搭建模型



2. 查看模型详情


三、 训练模型
1. 编写训练函数


2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器


3. 正式训练

model.train()、model.eval()训练营往期文章中有详细的介绍。

📌如果将优化器换成 SGD 会发生什么呢?请自行探索接下来发生的诡异事件的原因。



四、 结果可视化
1. Loss与Accuracy图

 

TRAINING AND VALIDATION ACCURACY

TRAINING AND

VALIDATION LOSS

1.4

1.0

TRAINING LOSS

1.2

TEST LOSS

1.0

0.8

0.8

0.6

0.6

0.4

0.4

0.2

TRAINING ACCURACY

TEST ACCURACY

0.0

75

35

15

35

25

20

40

30

40

30

output_29_0.png


2. 指定图片进行预测
 

Python复制代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):

test_img = Image.open(image_path).convert('RGB')

plt.imshow(test_img) # 展示预测的图片

test_img = transform(test_img)

img = test_img.to(device).unsqueeze(0)

model.eval()

output = model(img)

_,pred = torch.max(output,1)

pred_class = classes[pred]

print(f'预测结果是:{pred_class}')

Python复制代码

1

2

3

4

5

# 预测训练集中的某张照片

predict_one_image(image_path='./7-data/Dark/dark (1).png',

model=model,

transform=train_transforms,

classes=classes)

Plain Text复制代码

1

预测结果是:Dark

25

50

75

100

125

150

175

200

50

100

150

200

output_32_1.png


3. 模型评估
 

Python复制代码

1

2

best_model.eval()

epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)

Python复制代码

1

epoch_test_acc, epoch_test_loss

Plain Text复制代码

1

(0.9916666666666667, 0.035762640996836126)

Python复制代码

1

2

# 查看是否与我们记录的最高准确率一致

epoch_test_acc

Plain Text复制代码

1

0.9916666666666667

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/277795.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MFC画折线图,基于x64系统

由于项目的需要,需要画一个折线图。 传统的Teechart、MSChart、HighSpeedChart一般是只能配置在x86系统下,等到使用x64系统下运行就是会报出不知名的错误,这个地方让人很苦恼。 我在进行配置的过程之中,使用Teechart将x86配置好…

【Qt5】ui文件最后会变成头文件

2023年12月14日,周四下午 我也是今天下午偶然间发现这个的 在使用Qt的uic(User Interface Compiler)工具编译ui文件时,会生成对应的头文件。 在Qt中,ui文件是用于描述用户界面的XML文件,而头文件是用于在…

【C语言(十二)】

数据在内存中的存储 一、整数在内存中的存储 整数的2进制表示方法有三种,即 原码、反码和补码 有符号的整数,三种表示方法均有符号位和数值位两部分,符号位都是用0表示“正”,用1表示“负”,最高位的⼀位是被当做符号…

NNDL 作业10 BPTT

习题6-1P 推导RNN反向传播算法BPTT. 我的推导 和PPT结果对比,可得答案没问题 习题6-2 推导公式(6.40)和公式(6.41)中的梯度. 习题6-3 当使用公式(6.50)作为循环神经网络的状态更新公式时, 分析其可能存在梯度爆炸的原因并给出解决方法&…

【Spring的AOP】Spring的简介、案例与工作流程

文章目录 1. 什么是AOP2. AOP的核心概念3. AOP的入门案例原始代码思路分析第一步:导入坐标第二步:制作连接点(原始操作,Dao接口与实现类)第三步:制作共性功能(通知类与通知)第四步&a…

继续看回溯问题

关卡名 继续看回溯问题 我会了✔️ 内容 1.复习递归和N叉树,理解相关代码是如何实现的 ✔️ 2.理解回溯到底怎么回事 ✔️ 3.掌握如何使用回溯来解决二叉树的路径问题 ✔️ 1 复原IP地址 这也是一个经典的分割类型的回溯问题。LeetCode93.有效IP地址正好由四…

TrustZone之完成器:外围设备和内存

到目前为止,在本指南中,我们集中讨论了处理器,但TrustZone远不止是一组处理器功能。要充分利用TrustZone功能,我们还需要系统其余部分的支持。以下是一个启用了TrustZone的系统示例: 本节探讨了该系统中的关键组件以及它们在TrustZone中的作用。 完成器:外围设备…

概念解读稳定性保障

什么是稳定 百度百科关于稳定的定义: “稳恒固定;没有变动。” 很明显这里的“稳定”是相对的,通常会有参照物,例如 A 车和 B 车保持相同速度同方向行驶,达到相对平衡相对稳定的状态。 那么软件质量的稳定是指什么…

PhotoMaker——通过堆叠 ID 嵌入定制逼真的人像照片

论文网址链接:https://arxiv.org/abs/2312.04461 详情网址链接:PhotoMaker 开源代码网址链接:GitHub - TencentARC/PhotoMaker: PhotoMaker 文本到图像AI生成的最新进展在根据给定文本提示合成逼真的人类照片方面取得了显着进展。然而&#…

UDS DTC老化机制

文章目录 简介基本概念1、操作周期(Operation Cyle)2、错误计数(FDC, Fault Detection Counter)3、确认阈值(Confirmation Threshold)4、老化计数(Aging Counter)5、老化阈值(Aging Threshold) 老化条件非排放 DTC 示例参考 简介 当某个DTC在一定次数的操作循环内,…

蓝桥杯专题-真题版含答案-【扑克牌排列】【放麦子】【纵横放火柴游戏】【顺时针螺旋填入】

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

mac电脑html文件 局域网访问

windows html文件 局域网访问 参考 https://blog.csdn.net/qq_38935512/article/details/103271291mac电脑html文件 局域网访问 开发工具vscode 安装vscode插件 Live Server 完成后打开项目的html 右键使用Live Server打开页面 效果如下,使用本地ip替换http://12…