迁移学习实现图片分类任务

导入工具包

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Fimport matplotlib.pyplot as plt
%matplotlib inline# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

获取计算硬件

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

图片预处理

from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

这里对train训练集和text集的处理不同,几个transforms的操作通过compose进行整合。

载入图片分类数据集

# 数据集文件夹路径
dataset_dir = 'fruit30_split'train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

datasets下的ImageFolder,可以直接构建数据集。

类别与索引号一一对应

class_names = train_dataset.classes
n_class = len(class_names)# 映射关系:类别 到 索引号
train_dataset.class_to_idx

定义数据加载器Dataloader,dataloader用于给模型喂数据。

from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)

查看一个batch的图像与标注

# DataLoader 是 python生成器,每次调用返回一个 batch 的数据
images, labels = next(iter(train_loader))images. Shape
#torch.Size([32, 3, 224, 224])
labels
#tensor([11, 19,  3, 25, 29, 13, 21, 18, 11,  1, 13, 15, 13,  0, 15, 25,  0,  7,11, 10,  9,  6, 26,  2, 11, 10, 29, 29, 15,  8, 19,  8])

迁移学习范式

导入训练所用的工具包

from torchvision import models
import torch.optim as optim
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc
Linear(in_features=512, out_features=30, bias=True)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())

采用第一种迁移学习的方式,优化器采用的是Adam的优化器。

训练配置

model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss() # 训练轮次 Epoch
EPOCHS = 20

模拟一个batch的训练

这里着重注意反向传播三部曲

# 反向传播“三部曲”
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 优化更新

 运行完整训练

# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model. Train() #每次开始前将模型设置为训练模式for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)           # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward()                   # 损失函数对神经网络权重反向传播求梯度optimizer.step()                  # 优化更新神经网络权重

在测试集上进行初步测试

model.eval() #模型设置为测试模式
with torch.no_grad(): #不再回传梯度correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数,如果预测类别等于标注类别print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

保存模型

torch.save(model, 'checkpoint/fruit30_pytorch_C1.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/443622.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【实战】使用Helm在K8S集群安装MySQL主从

文章目录 前言技术积累什么是HelmStorageClass使用的工具版本 helm 安装 MySQL 1主2从1. 添加 bitnami 的仓库2. 查询 MySQL 资源3. 拉取 MySQL chart 到本地4. 对chart 本地 values-test.yaml 修改5. 对本地 templates 模板 修改6. 安装 MySQL 集群7. 查看部署的 MySQL 集群8.…

11.Ubuntu

目录 1. 什么是Ubuntu 1.1. 概述 1.2. Ubuntu版本简介 1.2.1. 桌面版 1.2.2. 服务器版 2. 部署系统 2.1. 新建虚拟机 2.2. 安装系统 2.3. 部署后的设置 2.3.1. 设置root密码 2.3.2. 关闭防火墙 2.3.3. 启用允许root进行ssh 2.3.4. 安装所需软件 2.3.5. 制作快照 …

经典左旋,指针面试题

今天给大家带来几道面试题! 实现一个函数,可以左旋字符串中的k个字符。 例如: ABCD左旋一个字符得到BCDA ABCD左旋两个字符得到CDAB 我们可以先自己自行思考,下面是参考答案: 方法一: #define _CRT_SEC…

linux 使用命令创建mysql账户

目录 前言创建步骤 前言 mysql默认有一个root用户,这个账户权限太大了,用起来不太安全,我们通常是重新那家一个账户用于一般的数据库操作,下面介绍如何通过命令创建一个mysql账户。 创建步骤 登录mysql mysql -u root -p输入roo…

如何发布自己的npm包:

1.创建一个打包组件或者库: 安装weback: 打开项目: 创建webpack.config.js,创建src目录 打包好了后发现两个js文件都被压缩了,我们想开发使用未压缩,生产使用压缩文件。 erserPlugin:(推荐使用…

KAFKA监控方法以及核心指标

文章目录 1. 监控指标采集1.1 部署kafka_exporter1.2 prometheus采集kafka_exporter的暴露指标1.3 promethues配置告警规则或者配置grafana大盘 2. 核心告警指标2.1 broker核心指标2.2 producer核心指标2.3 consumer核心指标 3. 参考文章 探讨kafka的监控数据采集方式以及需要关…

svn 安装路径

SVN客户端安装(超详细) 一、SVN客户端安装 1、下载安装包地址:https://tortoisesvn.net/downloads.html 此安装包是英文版的,还可以下载一个语言包,在同界面的下方 一直点击下一步,直到弹出选择红框 然…

电子信息找工作选fpga还是嵌入式?

电子信息找工作选fpga还是嵌入式? 在开始前我分享下我的经历,刚入行时遇到一个好公司和师父,给了我机会,两年时间从3k薪资涨到18k的, 我师父给了一些嵌入式学习方法和资料,让我不断提升自己,感…

【产业实践】使用YOLO V5 训练自有数据集,并且在C# Winform上通过onnx模块进行预测全流程打通

使用YOLO V5 训练自有数据集,并且在C# Winform上通过onnx模块进行预测全流程打通 效果图 背景介绍 当谈到目标检测算法时,YOLO(You Only Look Once)系列算法是一个备受关注的领域。YOLO通过将目标检测任务转化为一个回归问题,实现了快速且准确的目标检测。以下是YOLO的基…

免费的ChatGPT网站(7个)

还在为找免费的chatGPT网站或者应用而烦恼吗?博主归纳总结了7个国内非常好用,而且免费的chatGPT网站,AI语言大模型,我们都来接触一下吧。 免费!免费!免费!...,建议收藏保存。 1&…

使用ChatGPT学习大象机器人六轴协作机械臂mechArm

引言 我是一名机器人方向的大学生,近期学校安排自主做一个机器人方面相关的项目。学校给我们提供了一个小型的六轴机械臂,mechArm 270M5Stack,我打算使用ChatGPT让它来辅助我学习如何使用这个机械臂并且做一个demo。 本篇文章将记录我是如何使…

DDD学习使用

简介 DDD(Domain-Driven Design):领域驱动设计。 Eric Evans “领域驱动设计之父” DDD不是架构,而是一种方法论(Methodology)微服务架构从一出来就没有很好的理论支撑如何合理的划分服务边界,人们常常为服务要划分多…