PyTorch搭建AlexNet训练集

本次项目是使用AlexNet实现5种花类的识别。

训练集搭建与LeNet大致代码差不多,但是也有许多新的内容和知识点。

1.导包,不必多说。

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt   # 从 matplotlib.pyplot 导入 imshow 函数
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

2.指定设备

device函数用来指定在训练过程中所使用的设备:如果有可用的GPU,那么使用第一块GPU,如果没有就默认使用cpu。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

 3.数据预处理函数

单独定义出来,当key为“train”或为“val”时,返回数据集要使用的一系列预处理方法。

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),   # 把图片重新裁剪为224*224transforms.RandomHorizontalFlip(),  # 水平方向随机翻转transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

4.获取数据集的路径

os.getcwd()方法获取当前文件所在的目录

os.path.join()方法将当前路径与上两级路径链接起来

image_path:获取到flower_data所在路径

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train", # 获取训练集的路径transform=data_transform["train"])  # 训练预处理
train_num = len(train_dataset)  # 打印训练集有多少张照片

5.加载数据集分类文件 

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4} :数据集共分为五类
flower_list = train_dataset.class_to_idx  获取分类的名称所对应的索引值
cla_dict = dict((val, key) for key, val in flower_list.items())  将字典中键与值的位置对换

为什么要换位置

=>这样在预测后可以直接通过值给到我们最后的测试类别
json_str = json.dumps(cla_dict, indent=4) :将字典编码成json格式
with open('class_indices,json', 'w') as json_file:
        json_file.write(json_str)  :将键值对保存到json文件中,方便后续在预测时读取信息

下面是生成的json文件

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入json文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)

 6.载入测试集

代码大致与LeNet网络差不多,载入测试集的图片路径需要自己定义并进行预处理。

在使用matplotlib查看图片时,注意修改为batch_size=4,shuffle=True参数。

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,shuffle=False, num_workers=0)

6.5 查看测试数据 

在原来的基础上做了修改,原来使用test_data_iter.next() 调用方式,但是test_data_iter.next() 调用方式已经过时。在 Python 中,迭代器的 next() 方法应该直接调用,而不是使用 iter.next() 的形式。所以应该使用 next(test_data_iter) 代替 test_data_iter.next()

!! 再使用 imshow () 函数时调用 mayplotlib.pyplot 库!

test_data_iter = iter(validate_loader)
test_image, test_label = next(test_data_iter)# 查看图片
def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# show images
imshow(utils.make_grid(test_image))

 查看数据的结果

预测结果分别是 蒲公英、向日葵、 郁金香、郁金香

这个图片像素不高,不是很清楚,我专门去测试集的数据中找到了原图片,都预测对了。

7. 模型实例化

net = AlexNet(num_classes=5, init_weights=True)
net.to(device)  # 将网络指定到运行设备上
loss_function = nn.CrossEntropyLoss()  # 定义损失函数,针对多类别的损失交叉函数optimizer = optim.Adam(net.parameters(), lr=0.0002) # 设置优化器

8.开始训练模型

通过一整个for循环来实现模型训练,基本过程与LeNet网络实现差不多 (ps:PyTorch搭建LeNet训练集详细实现-CSDN博客)第五部分),新出现的代码做了注释解释。

save_path = './AlexNet.pth'
best_acc = 0.0
for epoch in range(10):net.train()  # 管理dropout方法running_loss = 0.0t1 = time.perf_counter()  # 调time包获取训练过程中的测试时间for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)  # 训练进度a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# 进行验证测试集net.eval()acc = 0.0with torch.no_grad():  # 禁止对pytorch对参数的追踪for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_num  # 计算准确率if accurate_test > best_acc:best_acc = accurate_test   # 如果新的准确率大于最好的那个,将新的赋值给best_acctorch.save(net.state_dict(), save_path)  # 保存路径print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

9.运行代码,查看结果

后面的准确率达到了60%多,还可以,我感觉我用GPU跑的还挺慢的,五十多秒,但是比用cpu跑的快。

全部代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,shuffle=True, num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0002)save_path = './AlexNet.pth'
best_acc = 0.0
for epoch in range(10):net.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)net.eval()acc = 0.0with torch.no_grad():for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_numif accurate_test > best_acc:best_acc = accurate_testtorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

学习碎碎念:

学习的道路上总会是遇到困难和麻烦的,不要心急,不要烦躁,一步一步的解决问题,慢慢来总会好的!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539916.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【STM32学习】基本定时器,输出比较模式,基本参数

1、概述 此项功能是用来控制一个输出波形,或者指示一段给定的的时间已经到时。 如输出PWM信号时,可用这个模式。 2、输出比较初始化函数,基本参数 以上函数是用来配置输出比较模块的,每个函数对应一个定时器的通道,配…

LVGL移植到ARM开发板(GEC6818开发板)

LVGL移植到ARM开发板(GEC6818开发板) 一、LVGL概述 LVGL(Light and Versatile Graphics Library)是一个开源的图形用户界面库,旨在提供轻量级、可移植、灵活和易于使用的图形用户界面解决方案。 它适用于嵌入式系统…

自然语言处理实验2 字符级RNN分类实验

实验2 字符级RNN分类实验 必做题: (1)数据准备:academy_titles.txt为“考硕考博”板块的帖子标题,job_titles.txt为“招聘信息”板块的帖子标题,将上述两个txt进行划分,其中训练集为70%&#xf…

概率论与数理统计(随机事件与概率)

1随机事件与概率 1.1随机事件及其运算规律 1.1.1运算 交换律结合律分配律德摩根律 1.2概率的定义及其确定方法 1.2.1概率的统计定义 频率 设在 n 次试验中,事件 A 发生了(A)次,则称为事件 A 发生的频率。 1.2.2概率的统计定义 在一组恒定不变的条…

GPT-SoVITS开源音色克隆框架的训练与调试

GPT-SoVITS开源框架的报错与调试 遇到的问题解决办法 GPT-SoVITS是一款创新的跨语言音色克隆工具,同时也是一个非常棒的少样本中文声音克隆项目。 它是是一个开源的TTS项目,只需要1分钟的音频文件就可以克隆声音,支持将汉语、英语、日语三种…

vscode 导入前端项目

vscode 导入前端项目 导入安装依赖 运行 参考vscode 下载 导入 安装依赖 运行 在前端项目的终端中输入npm run serve

KKVIEW: 远程控制软件哪个好用

远程控制软件哪个好用 随着科技的发展和工作方式的改变,远程控制软件越来越受到人们的关注和需求。无论是在家中远程办公,还是技术支持人员为远程用户提供帮助,选择一款高效稳定的远程控制软件至关重要。在众多选择中,有几款远程…

【数学建模】线性规划

针对未来可能的数学建模比赛内容,我对学习的内容做了一些调整,所以先跳过灰色关联分析和模糊综合评价的代码,今天先来了解一下运筹规划类——线性规划模型。 背景: 某数学建模游戏有三种题型,分别是A,B&am…

【AI论文阅读笔记】ResNet残差网络

论文地址:https://arxiv.org/abs/1512.03385 摘要 重新定义了网络的学习方式 让网络直接学习输入信息与输出信息的差异(即残差) 比赛第一名1 介绍 不同级别的特征可以通过网络堆叠的方式来进行丰富 梯度爆炸、梯度消失解决办法:1.网络参数的初始标准化…

微博热搜榜单采集,微博热搜榜单爬虫,微博热搜榜单解析,完整代码(话题榜+热搜榜+文娱榜和要闻榜)

文章目录 代码1. 话题榜2. 热搜榜3. 文娱榜和要闻榜 过程1. 话题榜2. 热搜榜3. 文娱榜和要闻榜 代码 1. 话题榜 import requests import pandas as pd import urllib from urllib import parse headers { authority: weibo.com, accept: application/json, text/pl…

jdk版本规则看这里

Java Development Kit (JDK) 的版本号是由几个不同的数字和有时的字母组合来定义的,这些数字和字母表达了版本的不同层面。下面是 JDK 版本号的一般结构和它们各自的含义: JDK 版本号的组成 主版本号 - 表示主要的发布版本。例如,在 JDK 8 或…

使用 WXT 开发浏览器插件(上手使用篇)

WXT (https://wxt.dev/), Next-gen Web Extension Framework. 号称下一代浏览器开发框架. 可一套代码 (code base) 开发支持多个浏览器的插件. 上路~ WXT 提供了脚手架可以方便我们快速进行开发,但是我们得先安装好环境依赖,这里我们使用 npm, 所以需要…