pytorch搭建AlexNet网络实现花分类

pytorch搭建AlexNet网络实现花分类

  • 一、AlexNet网络
    • 概述
    • 分析
  • 二、数据集准备
    • 下载
    • 划分训练集和测试集
  • 三、代码
    • model.py
    • train.py
    • predict.py

一、AlexNet网络

概述

在这里插入图片描述
使用Dropout的方式在网络正向传播过程中随机失活一部分神经元,以减少过拟合
在这里插入图片描述

分析

对其中的卷积层、池化层和全连接层进行分析

1,Conv1
注意:图片中用了两块GPU并行计算,上下两组图结构一样。
在这里插入图片描述

  • 输入:input_size = [224, 224, 3]
  • 卷积层:
    kernels = 48 * 2 = 96 组卷积核
    kernel_size = 11
    padding = [1, 2] (左上围加半圈0,右下围加2倍的半圈0)
    stride = 4
  • 输出:output_size = [55, 55, 96]

经 Conv1 卷积后的输出层尺寸为:
在这里插入图片描述
2,Maxpool1
在这里插入图片描述

  • 输入:input_size = [55, 55, 96]
  • 池化层:(只改变尺寸,不改变深度channel)
    kernel_size = 3
    padding = 0
    stride = 2
  • 输出:output_size = [27, 27, 96]

经 Maxpool1 后的输出层尺寸为:
在这里插入图片描述
3,Conv2
在这里插入图片描述

  • 输入:input_size = [27, 27, 96]
  • 卷积层:
    kernels = 128 * 2 = 256 组卷积核
    kernel_size = 5
    padding = [2, 2]
    stride = 1
  • 输出:output_size = [27, 27, 256]

经 Conv2 卷积后的输出层尺寸为:
在这里插入图片描述
4,Maxpool2
在这里插入图片描述

  • 输入:input_size = [27, 27, 256]
  • 池化层:(只改变尺寸,不改变深度channel)
    kernel_size = 3
    padding = 0
    stride = 2
  • 输出:output_size = [13, 13, 256]

经 Maxpool2 后的输出层尺寸为:
在这里插入图片描述
5,Conv3
在这里插入图片描述

  • 输入:input_size = [13, 13, 256]
  • 卷积层:
    kernels = 192* 2 = 384 组卷积核
    kernel_size = 3
    padding = [1, 1]
    stride = 1
  • 输出:output_size = [13, 13, 384]

经 Conv3 卷积后的输出层尺寸为:
在这里插入图片描述
6,Conv4
在这里插入图片描述

  • 输入:input_size = [13, 13, 384]
  • 卷积层:
    kernels = 192* 2 = 384 组卷积核
    kernel_size = 3
    padding = [1, 1]
    stride = 1
  • 输出:output_size = [13, 13, 384]

经 Conv4 卷积后的输出层尺寸为:
在这里插入图片描述
7,Conv5
在这里插入图片描述

  • 输入:input_size = [13, 13, 384]
  • 卷积层:
    kernels = 128* 2 = 256 组卷积核
    kernel_size = 3
    padding = [1, 1]
    stride = 1
  • 输出:output_size = [13, 13, 256]

经 Conv5 卷积后的输出层尺寸为:
在这里插入图片描述
8,Maxpool3
在这里插入图片描述

  • 输入:input_size = [13, 13, 256]
  • 池化层:(只改变尺寸,不改变深度channel)
    kernel_size = 3
    padding = 0
    stride = 2
  • 输出:output_size = [6, 6, 256]

经 Maxpool3 后的输出层尺寸为:
在这里插入图片描述
9,FC1、FC2、FC3
Maxpool3 → (6*6*256) → FC1 → 2048 → FC2 → 2048 → FC3 → 1000
最终的1000可以根据数据集的类别数进行修改。

二、数据集准备

下载

包含 5 中类型的花,每种类型有600~900张图像不等。
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
在这里插入图片描述

划分训练集和测试集

此数据集不同于 CIFAR10 下载时已经划分完成,需要自行划分。
shift + 右键 打开 PowerShell ,执行 “split_data.py” 分类脚本自动将数据集划分成 训练集train 和 验证集val。
split_data.py 代码如下:

import os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)# 获取 flower_photos 文件夹下除 .txt 文件以外所有文件夹名( 即5种花的类名)
file_path = './flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla] # 创建 训练集train 文件夹,并由5种类名在其目录下创建5个子目录
mkfile('flower_data/train')
for cla in flower_class:mkfile('flower_data/train/'+cla)# 创建 验证集val 文件夹,并由5种类名在其目录下创建5个子目录
mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1# 遍历5种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:cla_path = file_path + '/' + cla + '/'  # 某一类别花的子目录images = os.listdir(cla_path)		    # iamges 列表存储了该目录下所有图像的名称num = len(images)eval_index = random.sample(images, k=int(num*split_rate)) # 从images列表中随机抽取 k 个图像名称for index, image in enumerate(images):# eval_index 中保存验证集val的图像名称if image in eval_index:					image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)  # 将选中的图像复制到新路径# 其余的图像保存在训练集train中else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

通过修改 split_data.py 中的路径和文件名称参数,可以实现对其他数据集进行划分。

三、代码

model.py

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[96, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[96, 27, 27]nn.Conv2d(96, 256, kernel_size=5, padding=(2, 2)),      # output[256, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[256, 13, 13]nn.Conv2d(256, 384, kernel_size=3, padding=(1, 1)),     # output[384, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(384, 384, kernel_size=3, padding=(1, 1)),     # output[384, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=(1, 1)),     # output[256, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[256, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()# 前向传播过程def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)	# 展平后再传入全连接层x = self.classifier(x)return x# 网络权重初始化,实际上 pytorch 在构建网络时会自动初始化权重def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):                            # 若是卷积层nn.init.kaiming_normal_(m.weight, mode='fan_out',   # 用(何)kaiming_normal_法初始化权重nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)                    # 初始化偏重为0elif isinstance(m, nn.Linear):            # 若是全连接层nn.init.normal_(m.weight, 0, 0.01)    # 正态分布初始化nn.init.constant_(m.bias, 0)          # 初始化偏重为0

注:为了加快训练,可以只使用了一半的网络参数,如下所示:

class AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()# 用nn.Sequential()将网络打包成一个模块,精简代码self.features = nn.Sequential(   # 卷积层提取图像特征nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]nn.ReLU(inplace=True), 									# 直接修改覆盖原值,节省运算内存nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(   # 全连接层对图像分类nn.Dropout(p=0.5),			   # Dropout 随机失活神经元,默认比例为0.5nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()

train.py

  • 数据预处理

在对训练集的预处理,多了随机裁剪和水平翻转这两个步骤。可以起到扩充数据集的作用,增强模型泛化能力。

transforms.RandomResizedCrop(224),       # 随机裁剪,再缩放成 224×224
transforms.RandomHorizontalFlip(p=0.5),  # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转
  • 导入和加载数据

不同于 CIFAR10 数据集,花分类数据集并不在 pytorch 的 torchvision.datasets. 中,因此需要用到 datasets.ImageFolder() 来导入。
ImageFolder()返回的对象是一个包含数据集所有图像及对应标签构成的二维元组容器,支持索引和迭代,可作为torch.utils.data.DataLoader的输入。

  • 存储 索引:标签 的字典

为了方便在 predict 时读取信息,将 索引:标签 存入到一个 json 文件中

# 字典,类别:索引 {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# 将 flower_list 中的 key 和 val 调换位置
cla_dict = dict((val, key) for key, val in flower_list.items())# 将 cla_dict 写入 json 文件中
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)

class_indices.json 文件内容如下:

{"0": "daisy","1": "dandelion","2": "roses","3": "sunflowers","4": "tulips"
}
  • 完整训练代码
# 导入包
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time# 使用GPU训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),       # 随机裁剪,再缩放成 224×224transforms.RandomHorizontalFlip(p=0.5),  # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}# 获取图像数据集的路径
image_path = "./flower_photos/flower_data/"  # flower data_set path# 导入训练集并进行预处理
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)
# 按batch_size分批次加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset,   # 导入的训练集batch_size=32,   # 每批训练的样本数shuffle=True,    # 是否打乱训练集num_workers=0)   # 使用线程数,在windows下设置为0# 导入验证集并进行预处理
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
# 加载验证集
validate_loader = torch.utils.data.DataLoader(validate_dataset,	# 导入的验证集batch_size=32, shuffle=True,num_workers=0)# 字典,类别:索引 {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# 将 flower_list 中的 key 和 val 调换位置
cla_dict = dict((val, key) for key, val in flower_list.items())
# 将 cla_dict 写入 json 文件中
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)net = AlexNet(num_classes=5, init_weights=True)       # 实例化网络(输出类型为5,初始化权重)
net.to(device)                                        # 分配网络到指定的设备(GPU/CPU)训练
loss_function = nn.CrossEntropyLoss()                 # 交叉熵损失
optimizer = optim.Adam(net.parameters(), lr=0.0002)   # 优化器(训练参数,学习率)save_path = './AlexNet.pth'
best_acc = 0.0for epoch in range(10):########################################## train ###############################################net.train()                         # 训练过程中开启 Dropoutrunning_loss = 0.0                  # 每个 epoch 都会对 running_loss  清零time_start = time.perf_counter()    # 对训练一个 epoch 计时for step, data in enumerate(train_loader, start=0):  # 遍历训练集,step从0开始计算images, labels = data   # 获取训练集的图像和标签optimizer.zero_grad()	# 清除历史梯度outputs = net(images.to(device))                 # 正向传播loss = loss_function(outputs, labels.to(device)) # 计算损失loss.backward()                                  # 反向传播optimizer.step()                                 # 优化器更新参数running_loss += loss.item()# 打印训练进度(使训练过程可视化)rate = (step + 1) / len(train_loader)           # 当前进度 = 当前step / 训练一轮epoch所需总stepa = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print('%f s' % (time.perf_counter()-time_start))########################################### validate ###########################################net.eval()    # 验证过程中关闭 Dropoutacc = 0.0  with torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]  # 以output中值最大位置对应的索引(标签)作为预测输出acc += (predict_y == val_labels.to(device)).sum().item()    val_accurate = acc / val_num# 保存准确率最高的那次网络参数if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f \n' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json# 预处理
data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("向日葵.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))# 关闭 Dropout
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/1781.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyCat01——如何实现MySQL中的主从复制

1 问题 数据对于我们来说是一项最重要的资产,因为数据丢失带来的损失,对于一家公司来说,有时也是毁灭性的。 那么如何确保数据安全,不因断电或系统故障带来数据丢失呢? 当用户增加,对数据库的访问量也随…

【Soft-prompt Tuning for Large Language Models to Evaluate Bias 论文略读】

Soft-prompt Tuning for Large Language Models to Evaluate Bias 论文略读 INFORMATIONAbstract1 Introduction2 Related work3 Methodology3.1 Experimental setup 4 Results5 Discussion & Conclusion总结A Fairness metricsB Hyperparmeter DetailsC DatasetsD Prompt …

【Java】JVM学习(七)

JVM调优 堆空间如何设置 在分代模型中,各分区的大小对GC的性能影响很大。如何将各分区调整到合适的大小,分析活跃数据的大小是很好的切入点。 活跃数据的大小:应用程序稳定运行时长期存活对象在堆中占用的空间大小,也就是Full …

拧螺丝需求:递归算法的极致应用

前言 在一个平平无奇的下午,接到一个需求,需要给公司的中台系统做一个json报文重组的功能。 因为公司的某些业务需要外部数据的支持,所以会采购一些其它公司的数据,而且为了保证业务的连续性,同一种数据会采购多方的数…

Qt QSqlQueryModel详解

背景知识: Qt SQL的API分为不同层: 驱动层 驱动层 对于QT是基于C来实现的框架,该层主要包括QSqlDriver、QSqlDriverCreator、QSqlDriverCreatorbase、QSqlDriverPlugin and QSqlResult。这一层提供了特定数据库和SQL API层之间的底层桥梁…

Servlet(下篇)

哥几个来学 Servlet 啦 ~~ 这个是 Servlet(上篇)的链接, (2条消息) Servlet (上篇)_小枫 ~的博客-CSDN博客https://blog.csdn.net/m0_64247824/article/details/131229873主要讲了 Servlet的定义、Servlet的部署方式、…

C语言-基础语法学习-3 二级指针

目录 二级指针二级指针的定义和声明二级指针的初始化二级指针的使用二级指针和函数参数二级指针和动态内存分配数组指针二维数组二维数组的初始化二维数组与指针二维数组的遍历 二级指针 当涉及到多级指针时,C语言的灵活性和强大的指针功能可以得到充分的发挥。二级…

原生HTML+CSS+JS制作自己的导航主页

如果你想使用原生HTML、CSS和JS制作自己的导航主页&#xff0c;你可以按照以下步骤进行操作&#xff1a; 先看效果图&#xff1a; 创建HTML文件&#xff1a;首先&#xff0c;创建一个新的HTML文件&#xff0c;并在文件中添加基本的HTML结构。你可以使用<!DOCTYPE html>…

R语言复现一篇6分的孟德尔随机化文章

上一期我们对孟德尔随机化做了一个简单的介绍&#xff0c;今天我们来复现一篇6分左右的使用了孟德尔随机化方法的文章&#xff0c;文章的题目是&#xff1a;Mendelian randomization analysis does not reveal a causal influence of mental diseases on osteoporosis&#xff…

基于tensorflow深度学习的猫狗分类识别

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

排序算法——归并排序(递归与非递归)

归并排序 以升序为例 文章目录 归并排序基本思想核心步骤递归写法实现代码 非递归处理边界情况实现代码 时间复杂度 基本思想 归并排序是建立在归并操作上的一种有效的排序算法&#xff0c;该算法是采用分治法的一个非常典型的应用&#xff1a;将已有序的子序列合并&#xff…

《C++ Primer》--学习7

顺序容器 容器库概览 迭代器 与容器一样&#xff0c;迭代器有着公共的接口&#xff1a;如果一个迭代器提供某个操作&#xff0c;那么所有提供相同操作的迭代器对这个操作的实现方式都是相同的。 迭代器范围 一个迭代器范围是由一对迭代器表示&#xff0c;两个迭代器分别指向…