卷积神经网络--猫狗系列之构建模型【ResNet50】

在上一期:卷积神经网络--猫狗系列之下载、导入数据集,如果测试成功就说明对数据的预处理工作已经完成,接下来就是构建模型阶段了:

据说建立一个神经网络模型比较简单,只要了解了各层的含义、不同层之间参数的传递等等,那么一个完整的网络模型就可以被容易地构建出来。(这对于我这种初学的同学来说嘎嘎困难哈哈)

不扯了,加载一个预训练模型---ResNet50:

借助torchvision库,能够很容易的获得一组已经训练好的模型,这些模型大多数接收一个称为pretrained的参数,当这个参数为True时,它会下载为ImageNet分类问题调整好的权重。

network1=models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

紧接着,我们需要冻结所有层,所有的权重不会随着训练而更新:

for param in network1.parameters():    param.requires_grad=False

然后!由于这个预训练模型不是专门针对这样猫猫狗狗的二分类问题,所以,我们需要将最后一层的输出特征从1000改为2。(默认是1000类)

首先,我们需要知道最后一层的名字,通过network1查看一下,告诉大家怎么操作:

将network1输在pycharm中,把光标放在这一行,然后快捷键【shift+alt+E】(也可以鼠标选中然后右键专门运行这一行,这其实就是pycharm的交互模式)

得到这样的界面:

【可以知道最后一层是一个全连接层,名为fc。其他的东西以后有时间再分析,就是一些卷积、池化以及激活操作】

所以,我们就要将最后一层替换为输出特征为2的全连接层:

import torch.nn as nnnetwork1.fc=nn.Linear(2048,2)

此时,该层为新的层,所以它的requires_grad=True,这样整个网络仅有这一层可以更新权重。

打印更新后的网络:

该猫猫狗狗的模型构建代码总结:

from torchvision import modelsimport torch.nn as nnimport torch.optim as optim#网络搭建network1=models.resnet50(weights=models.ResNet50_Weights.DEFAULT)for param in network1.parameters():    param.requires_grad=Falsenetwork1.fc=nn.Linear(2048,2)

利用已经训练好的模型主要目的是它能够提取出非常好的特征,最后一层接收前面层提取的特征,然后误差反向传播,仅更新这一层的权重,不断迭代。

【此猫狗系列会继续更】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/4050.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Thunder送书 | 第三期 】「Python系列丛书」

文章目录 前言《Python高效编程——基于Rust语言》《Python从入门到精通》《Python Web深度学习》《Python分布式机器学习》文末福利 | 赠书活动 前言 Thunder送书第三期开始啦!前面两期都是以【文末送书】的形式开展,本期将赠送Python系列丛书&#xff…

html_css模拟端午赛龙舟运动

文章目录 ⭐前言💖 样式布局💖 添加龙舟💖 添加css_animation运动 ⭐结束 ⭐前言 大家好,我是yma16,本期给大家分享css实现赛龙舟运动。 💖 样式布局 风格:卡通 首先采用一张包括水元素的照片…

docker安装rabbitMQ,JAVA连接进行生产和消费,压测

1.docker安装 docker安装以及部署_docker bu shuminio_春风与麋鹿的博客-CSDN博客 2.doker安装rabbitMQ 使用此命令获取镜像列表 docker search rabbitMq 使用此命令拉取镜像 docker pull docker.io/rabbitmq:3.8-management 这里要注意,默认rabbitmq镜像是…

【MySQL】表的操作

目录 一、创建表 1、创建规则 2、创建案例 二、查看表结构 三、修改表 1、更改表名 2、 向表中插入数据 3、在表中添加字段 4、修改字段属性 5、从表中删除字段 6、修改字段名字 四、删除表 一、创建表 1、创建规则 CREATE TABLE table_name (field1 datatype,fi…

【C++】C++关于异常的学习

文章目录 C语言传统的处理错误的方式一、异常的概念及用法二、自定义异常体系总结 C语言传统的处理错误的方式 传统的错误处理机制: 1. 终止程序,如 assert ,缺陷:用户难以接受。如发生内存错误,除 0 错误时就会终止…

PyQt中数据库的访问(一)

访问数据库的第一步是确保ODBC数据源配置成功,我接下来会写数据源配置的文章,请继续关注本栏! (一)数据库连接 self.DBQSqlDatabase.addDatabase("QODBC") self.DB.setDatabaseName("Driver{sqlServer…

chatGPT AI对话聊天绘画系统开发:打开人工智能AI社交聊天系统开发新时代

人工智能技术的快速发展和普及,催生了众多创新应用,其中,AI社交聊天系统成为当下市场的热门话题,本文将详细介绍开发属于自己的ChatGPT的过程,并探讨当下市场因Chat AI聊天系统所带来的影响性。 AI社交聊天系统的潜力与…

云原生(第一篇)k8s-组件说明

k8s是什么? go语言开发的开源的跨主机的容器编排工具;全称是kubernetes; k8s的组件: master: ①kube-apiserver 所有服务统一的访问入口,无论对内还是对外; ②kube-controller-manager 资源控…

C++11新特性 智能指针

智能指针 nuique_ptr特点不允许拷贝构造和赋值运算符重载-> () *unique_ptr 删除器仿写删除文件删除普通对象 shared_ptr特点示意图仿写shared_ptr删除器部分特化拷贝构造 移动构造 && 左值赋值 和移动赋值完整实现 weak_ptr特点weak_ptr 实现解决循环引用弱指针一个…

事务

事务回顾MySQL事务Spring事务实现编程式事务实现:声明式事务 Transactional 注解作用范围及名称(value/transactionManager)隔离级别:isolation超时时间:timeout修改只读事务指定异常异常捕获情况 事务失效场景Transac…

九、ElasticSearch 运维 -集群维度

1. 查看集群健康 用于简单的判断集群的健康状态,集群内的分片的分配迁移情况。 GET _cluster/health-------------------------Respond----------------------------- {"cluster_name" : "test-jie","status" : "green",…

【Python】 Windows上通过git bash执行python卡住的解决方法

解决方法 编辑 C:\Program Files\Git\etc\profile.d\aliases.sh,将python2.7改成python 编辑完成后,重启git bash, 输入python即可 参考 https://blog.csdn.net/ofreelander/article/details/112058975