现代卷积网络实战系列4:PyTorch从零构建VGGNet训练MNIST数据集

🌈🌈🌈现代卷积网络实战系列 总目录

本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

1、MNIST数据集处理、加载、网络初始化、测试函数
2、训练函数、PyTorch构建LeNet网络
3、PyTorch从零构建AlexNet训练MNIST数据集
4、PyTorch从零构建VGGNet训练MNIST数据集
5、PyTorch从零构建GoogLeNet训练MNIST数据集
6、PyTorch从零构建ResNet训练MNIST数据集

8、VGGNet

2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司的研究员一起研发出了新的深度卷积神经网络:VGGNet,并取得了ILSVRC2014比赛分类项目的第二名(第一名是GoogLeNet,也是同年提出的).论文下载 Very Deep Convolutional Networks for Large-Scale Image Recognition。论文主要针对卷积神经网络的深度对大规模图像集识别精度的影响,主要贡献是使用很小的卷积核(3×3)构建各种深度的卷积神经网络结构,并对这些网络结构进行了评估,最终证明16-19层的网络深度,能够取得较好的识别精度。 这也就是常用来提取图像特征的VGG-16和VGG-19。

VGG可以看成是加深版的AlexNet,整个网络由卷积层和全连接层叠加而成,和AlexNet不同的是,VGG中使用的都是小尺寸的卷积核(3×3)。
我这里使用的是VGG-16,但是又因为这个系列全部是处理MNIST数据集的,所以我这里的VGG网络只用了3个VGG块,FC也减少了很多参数。

9、VGGNet网络架构

在这里插入图片描述

VGGNet(
 (vgg1): VGGBlock(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU(inplace=True)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU(inplace=True)
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
 (vgg2): VGGBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU(inplace=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU(inplace=True)
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU(inplace=True)
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 )
 (vgg3): VGGBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU(inplace=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU(inplace=True)
  (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU(inplace=True)
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 )
(classifier): Sequential(
  (0): Linear(in_features=12544, out_features=1024, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=1024, out_features=512, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=512, out_features=10, bias=True)
 )
)

VGG实际上就是很简单,主要是由VGG块组成:
前两组卷积形式一样,每组都是:conv-relu-conv-relu-pool
中间三组卷积形式一样,每组都是:conv-relu-conv-relu-conv-relu-pool
最后分类的三个全连接层:fc-relu-dropout-fc-relu-dropout-fc-softmax

10、PyTorch构建VGGBlock

class VGGBlock(nn.Module):def __init__(self, in_channel, out_channel, num_conv):super(VGGBlock, self).__init__()self.num_conv = num_convself.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU(inplace=True)self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)self.relu3 = nn.ReLU(inplace=True)self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)def forward(self, x):x = self.relu1(self.conv1(x))x = self.relu2(self.conv2(x))if self.num_conv==3:x = self.relu3(self.conv3(x))else:x = self.maxpool1(x)return x

11、PyTorch构建VGGNet

class VGGNet(nn.Module):def __init__(self, num_classes):super(VGGNet, self).__init__()self.vgg1 = VGGBlock(1,64,2)self.vgg2 = VGGBlock(64,128,2)self.vgg3 = VGGBlock(128,256,3)self.classifier = nn.Sequential(nn.Linear(256 * 7 * 7, 1024),nn.ReLU(inplace=True),nn.Linear(1024, 512),nn.ReLU(inplace=True),nn.Linear(512, num_classes))def forward(self, x):x = self.vgg1(x)x = self.vgg2(x)x = self.vgg3(x)x = x.reshape(x.shape[0], -1)x = self.classifier(x)return x

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network
test loss 2303.1
test accuracy 10.1%

epoch step 1
training time 8.9s
training loss 204.3
test loss 39.6
test accuracy 98.8%

epoch step 2
training time 8.3s
training loss 48.8
test loss 39.7 test
accuracy 98.8%

epoch step 3
training time 8.1s
training loss 35.9
test loss 26.4
test accuracy 99.1%

Training finished
3 epoch training time 25.4s
One epoch average training time 8.5s

进程已结束,退出代码为 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/120795.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

性能测试工具 — JMeter

一、JMeter准备工作 1、JMeter介绍 Apache JMeter 应用程序是开源软件,是一个 100% 纯 Java 应用程序。用于测试Web应用程序、API和其他网络协议的性能。它具有以下特点: 1. 开源免费:JMeter是Apache软件基金会下的一个开源项目&#xff0…

HashMap常见面试题

简介 HashMap最早出现在JDK1.2中,底层基于散列算法实现。HashMap 允许 null 键和 null 值,是非线程安全类,在多线程环境下可能会存在问题。 1.8版本的HashMap数据结构: 为什么有的是链表有的是红黑树? 默认链表长度大…

京东大型API网关实践之路

概述 1、背景 京东作为电商平台,近几年用户、业务持续增长,访问量持续上升,随着这些业务的发展,API网关应运而生。 API网关,就是为了解放客户端与服务端而存在的。对于客户端,使开放给客户端的接口标准统…

stable diffusion模型评价框架

GhostReview:全球第一套AI绘画ckpt评测框架代码 - 知乎大家好,我是_GhostInShell_,是全球AI绘画模型网站Civitai的All Time Highest Rated (全球历史最高评价) 第二名的GhostMix的作者。在上一篇文章,我主要探讨自己关于ckpt的发展方向的观点…

ESP8266使用记录(四)

放上最终效果 ESP8266&Unity游戏 整合放进了坏玩具车遥控器里 最终只使用了mpu6050的yaw数据,因为roll值漂移…… 使用了https://github.com/ElectronicCats/mpu6050 整个流程 ESP8266取MPU6050数据,处理后通过udp发送给Unity显示出来 MPU6050_Z…

ElementUI之首页导航与左侧菜单

目录 一、Mock 1.1 什么是Mock.js 1.2 安装与配置 1.2.1 安装mock.js 1.2.2 引入mock.js 1.3 mock.js使用 1.3.1 定义测试数据文件 1.3.2 mock拦截Ajax请求 1.3.3 界面代码优化 二、总线 2.1 定义 2.2 类型分类 2.3 前期准备 2.4 配置组件与路由关系 2.4.1 配置…

基于PHP+MySQL的养老院管理系统

摘要 随着21世纪互联网时代的兴起,我们见证了人们生活方式的巨大改变。这个时代不仅深刻影响了我们的生活,还改变了我们对信息科学的看法。社会的各个领域都在不断发展,人们的思维也在不断进步,与此同时,信息的需求也与…

紫光同创FPGA图像视频采集系统,基于OV7725实现,提供工程源码和技术支持

目录 1、前言免责声明 2、设计思路框架视频源选择OV7725摄像头配置及采集动态彩条HDMA图像缓存输入输出视频HDMA缓冲FIFOHDMA控制模块HDMI输出 3、PDS工程详解4、上板调试验证并演示准备工作静态演示动态演示 5、福利:工程源码获取 紫光同创FPGA图像视频采集系统&am…

分布式文件存储系统minio、大文件分片传输

上传大文件 1、Promise对象 Promise 对象代表一个异步操作,有三种状态: pending: 初始状态,不是成功或失败状态。fulfilled: 意味着操作成功完成。rejected: 意味着操作失败。 只有异步操作的结果,可以决定当前是哪一种状态&a…

Windows/Linux下进程信息获取

Windows/Linux下进程信息获取 前言一、windows部分二、Linux部分三、完整代码四、结果 前言 Windows/Linux下进程信息获取,目前可获取进程名称、进程ID、进程状态 理论分析: Windows版本获取进程列表的API: CreateToolhelp32Snapshot() 创建进程快照,…

CTF 入门指南:从零开始学习网络安全竞赛

文章目录 写在前面CTF 简介和背景CTF 赛题类型介绍CTF 技能和工具准备好书推荐 写作末尾 写在前面 CTF比赛是快速提升网络安全实战技能的重要途径,已成为各个行业选拔网络安全人才的通用方法。但是,本书作者在从事CTF培训的过程中,发现存在几…

ip的标准分类---分类的Ip

分类的 IP 即将 IP 地址划分为若干个固定类,每一类地址都由两个固定长度的字段组成。 其中第一个字段是网络号(net-id),它标志主机或路由器所连接的网络。一个网络号在整个因特网内必须是唯一的。 第二个字段是主机号&#xf…