分类神经网络1:VGGNet模型复现

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]cfg = {'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True):super(VGG, self).__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):for layer in self.features:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_layers(cfg):layers = nn.ModuleList()in_channels = 3for i in cfg:if i == 'M':layers.append(nn.MaxPool2d(kernel_size=2, stride=2))else:layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))in_channels = ireturn layersdef vgg11_bn(num_classes):model = VGG(make_layers(cfg['A']), num_classes=num_classes)return modeldef vgg13_bn(num_classes):model = VGG(make_layers(cfg['B']), num_classes=num_classes)return modeldef vgg16_bn(num_classes):model = VGG(make_layers(cfg['C']), num_classes=num_classes)return modeldef vgg19_bn(num_classes):model = VGG(make_layers(cfg['D']), num_classes=num_classes)return modelif __name__=='__main__':import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = vgg16_bn(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/638042.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第一届 _帕鲁杯_ - CTF挑战赛

Mis 签到 题目附件: 27880 30693 25915 21892 38450 23454 39564 23460 21457 36865 112 108 98 99 116 102 33719 21462 21069 27573 102 108 97 103 20851 27880 79 110 101 45 70 111 120 23433 20840 22242 38431 22238 22797 112 108 98 99 116 102 33719 2…

【已解决】电脑设置notepad++默认打开txt

1、以管理员的方式打开notepad 步骤:打开设置 -> 首选项 -> 文件关联 2、 设置Notepad默认打开 按照以下步骤将Notepad设置为默认打开.txt文件: 右键单击任何一个.txt文件。选择“属性”。在“常规”选项卡中,找到“打开方式”&#…

STM32F1之I2C通信

目录 1. 简介 2. 硬件电路 3. IIC时序基本单元 3.1 发送一个字节 3.2 接收一个字节 3.3 发送应答 3.4 接收应答 1. 简介 I2C(Inter-Integrated Circuit)总线是由NXP Semiconductors(前身为Philips Semiconductor)…

Tomcat弱口令及war包漏洞复现(保姆级教程)

1.环境搭建 靶机:Ubuntu 安装参考:安装Ubuntu详细教程_乌班图安装教程-CSDN博客 vulhub docker搭建tomcat漏洞环境 参考:vulhub docker靶场搭建-CSDN博客 工具:burpsuite 2.漏洞复现 2.1弱口令爆破 进入http://192.168.143…

vscode 配置verilog环境

一、常用的设置 1、语言设置 安装如下插件,然后在config 2、编码格式设置 解决中文注释乱码问题。vivado 默认是这个格式,这里也设置一样。 ctrl shift p 打开设置项 3、插件信任区设 打开一个verilog 文件,显示是纯本文,没…

在centos系统中使用boost库

打开MobaXterm软件 下载 boost_1_85_0.tar.gz tar -zxvf boost_1_85_0.tar.gz解压缩成boost_1_85_0文件夹 双击arrayDemo.cpp 在里面可以编写代码 arrayDemo.cpp #include <boost/timer/timer.hpp> #include <boost/array.hpp> #include <cmath> #inc…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《风电租赁储能参与电能-调频市场竞价策略》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

Vue3+TS版本Uniapp:封装uni.request请求配置

作者&#xff1a;前端小王hs 阿里云社区博客专家/清华大学出版社签约作者✍/CSDN百万访问博主/B站千粉前端up主 封装请求配置项 封装拦截器封装uni.request 封装拦截器 uniapp的封装逻辑不同于Vue3项目中直接使用axios.create()方法创建实例&#xff08;在create方法中写入请求…

[Algorithm][二分查找][在排序数组中查找元素的第一个和最后一个位置][x 的平方根]详细讲解

目录 1.在排序数组中查找元素的第一个和最后一个位置1.题目链接2.算法原理详解1.查找区间左端点2.查找区间右端点 3.代码实现 2.x 的平方根1.题目链接2.算法原理详解3.代码实现 1.在排序数组中查找元素的第一个和最后一个位置 1.题目链接 在排序数组中查找元素的第一个和最后…

Rust序列化和反序列化

Rust 编写python 模块 必备库 docker 启动 nginx 服务 NGINX 反向代理配置

Spring Boot | Spring Boot 默认 “缓存管理“ 、Spring Boot “缓存注解“ 介绍

目录: 一、Spring Boot 默认 "缓存" 管理 :1.1 基础环境搭建① 准备数据② 创建项目③ 编写 "数据库表" 对应的 "实体类"④ 编写 "操作数据库" 的 Repository接口文件⑤ 编写 "业务操作列" Service文件⑥ 编写 "applic…

01-服务与服务间的通信

这里是极简版&#xff0c;仅用作记录 概述 前端和后端可以使用axios等进行http请求 服务和服务之间也是可以进行http请求的spring封装的RestTemplate可以进行请求 用法 使用bean注解进行依赖注入 在需要的地方&#xff0c;自动注入RestTemplate进行服务和服务之间的通信 注…