VGG16神经网络搭建

一、定义提取特征网络结构

将要实现的神经网络参数存放在列表中,方便使用。

数字代表卷积核的个数,字符代表池化层的结构

cfgs = {"vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

二、 定义提取特征网络

如果遍历过程中 v== 'M',就是定义池化层,后面的卷积核与stride步距都是网络的默认参数。

数字代表的就是定义卷积层,然后与激活函数链接在一起。

最后返回时,以非关键字参数的形式传入。

def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)

三、初始化网络

传入参数features,class_num,是否需要初始化权重。

定义分类网络结构,dropout方法缓解过拟合问题,再全连接核relu激活函数链接起来。

如果需要初始化权重,那么就会进入初始化权重的函数中。

class VGG(nn.Module):def __init__(self, features, class_num=1000, init_weight=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(512*7*7, 2048),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(True),nn.Linear(2048, class_num))if init_weight:self._initialize_weights()

 四、初始化权重函数

这个函数会遍历网络的每一个子模块。

如果遍历的当前层是一个卷积层,那么这个方法会初始化卷积核的权重,如果采用了偏置,那就默认初始化为0.

如果遍历的当前层是全连接层,也是用这个方法去初始化全连接层的权重,并将偏置设置为0.

    def _initialize_weights(self):for m in self.modules():  # 遍历模块中的每一个子模块if isinstance(m, nn.Conv2d):nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.constant_(m.bias, 0)

五、定义正向传播

x:输入的图像数据

features:提取网络特征结构

flatten:展平处理。因为第0个维度是batch,所以我们从第一个维度开始展平

经过分类网络结构后返回

    def forword(self, x):x = self.features(x)x = torch.flatten(x, strat_dim=1)x = self.classifier(x)return x

六、实例化模型

传入参数model_name:实例化给定的配置模型。

将key值传入字典当中

通过VGG这个类来实例化这个网络

features通过make_features这个函数来实现

最后创建对象实现VGG神经网络的搭建。 

def vgg(model_name="vgg16", **kwargs):try:cfg = cfgs[model_name]except:print("waring: model number {} not in cfgs dict".format(model_name))model = VGG(make_features(cfg), **kwargs)return  modelvgg_model = vgg(model_name='vgg13')

 运行成功,网络搭建成功。

 全部代码

import torch.nn as nn
import torchclass VGG(nn.Module):def __init__(self, features, class_num=1000, init_weight=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(512*7*7, 2048),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(True),nn.Linear(2048, class_num))if init_weight:self._initialize_weights()def forword(self, x):x = self.features(x)x = torch.flatten(x, strat_dim=1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():  # 遍历模块中的每一个子模块if isinstance(m, nn.Conv2d):nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.constant_(m.bias, 0)def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {"vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg(model_name="vgg16", **kwargs):try:cfg = cfgs[model_name]except:print("waring: model number {} not in cfgs dict".format(model_name))model = VGG(make_features(cfg), **kwargs)return  modelvgg_model = vgg(model_name='vgg13')

 全部代码与分开模块的顺序不同,但不影响最终实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/571839.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LIS、LCS算法模型

文章目录 1.LCS算法模型2.LIS算法模型 1.LCS算法模型 LCS问题就是给定两个序列A和B,求他们最长的公共子序列。 在求解时,我们会设dp[i][j]表示为A[1 ~ i]序列和B[1 ~ j]序列中(不规定结尾)的最长子序列的长度。 if(a[i]b[i]) dp…

MFC标签设计工具 图片控件上,移动鼠标显示图片控件内的鼠标xy的水平和垂直辅助线要在标签模板上加上文字、条型码、二维码 找准坐标和字体大小 源码

需求:要在标签模板上加上文字、条型码、二维码 找准坐标和字体大小 我生成标签时,需要对齐和 调文字字体大小。这工具微调 能快速知道位置 和字体大小。 标签设计(点击图片,上下左右箭头移动 或-调字体) 已经够用了,滚动条还没完…

静态代理,jdk动态代理,cglib动态代理

文章目录 静态代理动态代理jdk动态代理JDK生成的动态代理类大概源码cglib动态代理 代理模式就是用代理对象代替真实对象去完成相应的操作,并且能够在操作执行的前后对操作进行增强处理。 静态代理 mybatis使用的就是静态代理,相比动态代理,…

Mamba: Linear-Time Sequence Modeling with Selective State Spaces(论文笔记)

What can I say? 2024年我还能说什么? Mamba out! 曼巴出来了! 原文链接: [2312.00752] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (arxiv.org) 原文笔记: What: Mamba: Linear-Time …

STL的基本概念

一、STL的诞生 长久以来,软件界一直希望建立一种可重复利用的东西 C的面向对象和泛型编程思想,目的就是复用性的提升 面向对象的三大特性(简单理解) 封装:把属性和行为抽象出来作为一个整体来实现事和物 继承:子类继承父类&a…

Gui guider使用自定义字体总结

在实际开发中,我们通常是使用自定义字体。 在 LVGL 中,用户需要使用自定义的字库,其实现方法可分为两类: ① 通过 C 语言数组(内部读取); ② 通过文件系统读取字库(外部读取&#…

【CAD建模号小技巧】边缘尖角光滑处理方法

教大家一个处理模型边缘的方法,处理后模型更美观,更不易坏,而且有些零件还必须经过这样的处理。 咱们看一个未经过边缘处理的模型,边缘是尖的,摸到会刺伤,一些接近刀口形状。 更危险了,所以要进…

javascript基础代码练习

一、输入新增病例数&#xff0c;累计确诊病例数&#xff0c;14天内聚集性疫情发生天数。新增或者累计确诊病例为0则该地区为低风险地区。新增大于0且累计确诊&#xff1c;50或者累计大于50且14天内聚集性疫情发生天数为0的地区为中风险地区。其他情况为高风险地区。 <!DOCT…

大数据开发(离线实时音乐数仓)

大数据开发&#xff08;离线实时音乐数仓&#xff09; 一、数据库与ER建模1、数据库三范式2、ER实体关系模型 二、数据仓库与维度建模1、数据仓库&#xff08;Data Warehouse、DW、DWH&#xff09;1、关系型数据库很难将这些数据转换成企业真正需要的决策信息&#xff0c;原因如…

C语言程序练习——汉诺塔递归

1. 题目 在终端输入汉诺塔层数n&#xff0c;实现将n层汉诺塔通过三座塔座A、B、C进行排列 2. 代码 #include <stdio.h>int hannuota(int len, int str, int tmp, int dst) {if (1 len){printf("%c -> %c\n", str, dst);}else{hannuota(len-1, str, dst, …

Python更改Word文档的页面大小

页面大小确定文档中每个页面的尺寸和布局。在某些情况下&#xff0c;您可能需要自定义页面大小以满足特定要求。在这种情况下&#xff0c;Python可以帮助您。通过利用Python&#xff0c;您可以自动化更改Word文档中页面大小的过程&#xff0c;节省时间和精力。本文将介绍如何使…

Python---Numpy学习

首先&#xff0c;先来认识一下Numpy数组对象&#xff0c;以及如何创建它 import numpy as np# 1.认识数组对象 # 指定取值范围和跨度创建数组对象 # 创建一个3行4列的数组 data np.arange(12).reshape(3, 4)print(data)print(type(data))# 维度 print(data.shape)# 维度的个数…