3.AlexNet--CNN经典网络模型详解(pytorch实现)

        看博客AlexNet--CNN经典网络模型详解(pytorch实现)_alex的cnn-CSDN博客,该博客的作者写的很详细,是一个简单的目标分类的代码,可以通过该代码深入了解目标检测的简单框架。在这里不作详细的赘述,如果想更深入的了解,可以看另一个博客实现pytorch实现MobileNet-v2(CNN经典网络模型详解) - 知乎 (zhihu.com)。

在这里,直接写AlexNet--CNN的代码。

1.首先建立一个model.py文件,用来写神经网络,代码如下:

#model.pyimport torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(  #打包nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后nn.ReLU(inplace=True), #inplace 可以载入更大模型nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),#全链接nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1) #展平   或者view()x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值nn.init.constant_(m.bias, 0)

2.下载数据集

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

3.下载完后写一个spile_data.py文件,将数据集进行分类

#spile_data.pyimport os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:mkfile('flower_data/train/'+cla)mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)split_rate = 0.1
for cla in flower_class:cla_path = file + '/' + cla + '/'images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

之后应该是这样:
在这里插入图片描述

4.再写一个train.py文件,用来训练模型

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据转换data_transform = {#具体是对图像进行各种转换操作,并用函数compose将这些转换操作组合起来#以下操作步骤:# 1.图片随机裁剪为224X224# 2.随机水平旋转,默认为概率0.5# 3.将给定图像转为Tensor# 4.归一化处理"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# print(train_dataset)
# print(train_dataset[1][0].size())
# print(train_dataset.imgs[0][0])# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
i=1
# for step, data in enumerate(train_loader, start=0):
#     images, labels = data
#     print(i,"==>",images.shape,labels.shape)
#     i+=1validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=True,num_workers=0)test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.__next__()net = AlexNet(num_classes=5, init_weights=True)net.to(device)
#损失函数:这里用交叉熵
loss_function = nn.CrossEntropyLoss()
#优化器 这里用Adam
optimizer = optim.Adam(net.parameters(), lr=0.0002)
#训练参数保存路径
save_path = './AlexNet.pth'
#训练过程中最高准确率
best_acc = 0.0#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):# trainnet.train()    #训练过程中,使用之前定义网络中的dropoutrunning_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = data#print("step:     ",images.shape,labels)optimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# validatenet.eval()    #测试过程中不需要dropout,使用所有的神经元acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/638216.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SVN泄露(ctfhub)

目录 下载安装dvcs-ripper 使用SVN 一、什么是SVN? 使用SVN能做什么? 二、SVN泄露(ctfhub) SVN源代码漏洞的主要原因: 工具准备:dirsearch、dvcs-ripper 网络安全之渗透测试全套工具篇(内…

车载电子电器架构 —— 功能安全开发(首篇)

车载电子电器架构 —— 功能安全开发 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己…

ACE框架学习2

目录 ACE Service Configurator框架 ACE_Server_Object类 ACE_Server_Repository类 ACE_Server_Config类 ACE Task框架 ACE_Message_Queue类 ACE_TASK类 在开始之前&#xff0c;首先介绍一下模板类的实例化和使用。给出以下代码 //ACCEPTOR代表模板的方法 template <…

【数据结构项目】通讯录

个人主页点这里~ 原文件在gitee里~ 通讯录的实现 基于动态顺序表实现通讯录项目1、功能要求2、代码实现file.hfile.cList.hList.ctest.c 基于动态顺序表实现通讯录项目 准备&#xff1a;结构体、动态内存管理、顺序表、文件操作 1、功能要求 ①能够存储100个人的通讯信息 ②…

在PostgreSQL中如何实现分区表以提高查询效率和管理大型表?

文章目录 解决方案1. 确定分区键2. 创建分区表3. 数据插入与查询4. 维护与管理 示例代码1. 创建父表和子表2. 插入数据3. 查询数据 总结 随着数据量的增长&#xff0c;单一的大型表可能会遇到性能瓶颈和管理难题。PostgreSQL的分区表功能允许我们将一个大型表分割成多个较小的、…

一线实战,一次底层超融合故障导致的Oracle异常恢复

背景概述 某客户数据由于底层超融合故障导致数据库产生有大量的坏块&#xff0c;最终导致数据库宕机&#xff0c;通过数据抢救&#xff0c;恢复了全部的数据。下面是详细的故障分析诊断过程&#xff0c;以及详细的解决方案描述&#xff1a; 故障现象 数据库宕机之后&#xff0c…

Docker - 镜像、容器、仓库

原文地址&#xff0c;使用效果更佳&#xff01; Docker - 镜像、容器、仓库 | CoderMast编程桅杆Docker - 镜像、容器、仓库 提示 这个章节涉及到 Docker 最核心的知识&#xff0c;也是在使用过程中最常使用到的&#xff0c;需要重点学习。 什么是Docker镜像、容器、仓库&…

基于STM32的蓝牙小车的Proteus仿真(虚拟串口模拟)

文章目录 一、前言二、仿真图1.要求2.思路3.画图3.1 电源部分3.2 超声波测距部分3.3 电机驱动部分3.4 按键部分3.5 蓝牙部分3.6 显示屏部分3.7 整体 4.仿真5.软件 三、总结 一、前言 proteus本身并不支持蓝牙仿真&#xff0c;这里我采用虚拟串口的方式来模拟蓝牙控制。 这里给…

python爬虫--------requests案列(二十七天)

兄弟姐们&#xff0c;大家好哇&#xff01;我是喔的嘛呀。今天我们一起来学习requests案列。 一、requests____cookie登录古诗文网 1、首先想要模拟登录&#xff0c;就必须要获取登录表单数据 登录完之后点f12&#xff0c;然后点击network&#xff0c;最上面那个就是登录接口…

详解Java中的五种IO模型

文章目录 前言1、内核空间和用户空间2、用户态和内核态3、上下文切换4、虚拟内存5、DMA技术6、传统 IO 的执行流程 一、阻塞IO模型二、非阻塞IO模型三、IO多路复用模型1、IO多路复用之select2、IO多路复用之epoll3、总结select、poll、epoll的区别 四、IO模型之信号驱动模型五、…

VUE运行找不到pinia模块

当我们的VUE运行时报错Module not found: Error: Cant resolve pinia in时 当我们出现这个错误时 可能是 没有pinia模块 此时我们之要下载一下这个模块就可以了 npm install pinia

RattbitMQ安装

1.RabbitMQ是什么? RabbitMQ是消息队列的一种&#xff0c;生态好&#xff0c;好学习&#xff0c;易于理解&#xff0c;时效性强,支持很多不同语言的客户端,扩展性、可用性都很不错。学习性价比非常高的消息队列&#xff0c;适用于绝大多数中小规模分布式系统。 今天先来简单讲…