pytorch学习笔记(十一)

优化器学习

把搭建好的模型拿来训练,得到最优的参数。

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=1)
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x
#定义loss
loss = nn.CrossEntropyLoss()
tudui = Tudui()
#一开始时采用比较大的学习速率学习,后面用比较小的学习速率学习
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):#在每一轮学习之前都把loss设置成0#在每一轮的学习过程中计算的loss都加上去#这个数据是表示,在每一轮的学习的过程中在这一轮的整体的loss的求和,整体误差总和running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = tudui(imgs)result_loss = loss(outputs, targets)optim.zero_grad()#得到每一个可调参数的梯度result_loss.backward()optim.step()#损失函数没有已知在变化,原因是只有单个循环下,只看了一次数据,这一次看到的数据对你下一次看到的数据预测的影响不大# print(result_loss)running_loss = running_loss + result_lossprint(running_loss)

在debug的过程中选择最后三行,观察梯度变化

其中optim.step()会把每一步更新的梯度用于数据的更新

现有模型的使用和修改

参数:root (string) - ImageNet数据集的根目录。

split (string,可选)-数据集分割,支持train或val。

transform(可调用的,可选的)-一个函数/转换,接收PIL图像并返回转换后的版本。例如,变换。RandomCrop

target_transform (callable, optional) -一个函数/transform,接收目标并对其进行变换。

loader -加载给定路径的图像的函数。

这边看看VGG16,因为它的预训练数据集太大了,不好下载,这边采用CIFAR10代替ImageNet的方法。

然后发现他的线性层输出的特征是1000,也是分1000个类,而CIFAR10只有10个类,这需要对网络模型进行修改,两种思路进行修改。

(1)直接修改最后一个线性层(6),将输出特征改为10

(2)加个线性层(7),输入设置为1000,而输出设置为10

模型的保存和模型的加载

官方推荐的保存下来文件比较小

方式2输出的是一个字典形式,要恢复成网络结构,要新建这个模型,然后还要通过字典的形式重建。

另外要注意用方式1(陷阱)保存的时候要在加载的部分引入你定义的结构否则会报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/422541.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《WebKit 技术内幕》学习之六(2): CSS解释器和样式布局

2 CSS解释器和规则匹配 在了解了CSS的基本概念之后,下面来理解WebKit如何来解释CSS代码并选择相应的规则。通过介绍WebKit的主要设施帮助理解WebKit的内部工作原理和机制。 2.1 样式的WebKit表示类 在DOM树中,CSS样式可以包含在“style”元素中或者使…

最全笔记软件盘点!你要的笔记神器都在这里:手写笔记、知识管理、文本笔记、协作笔记等!

在当今的信息化社会中,人们对信息的处理速度越来越快,从工作到生活,我们都面临着大量信息的冲击。在这样的环境下,一个能够帮助我们管理、整理和储存信息的好工具显得尤为重要,而笔记软件恰恰可以满足这些需求。 在选…

中仕教育:国考调剂和补录的区别是什么?

国考笔试成绩和进面名单公布之后,考生们就需要关注调剂和补录了,针对二者之间的区别很多考生不太了解,本文为大家解答一下关于国考调剂和补录的区别。 1.补录 补录是在公式环节之后进行的,主要原因是经过面试、体检和考察&#…

在vscode中悄无声息地摸鱼

想法 作为前端开发者,大多数人都使用 VSCode,并且可能会找一些在 VSCode 中可以摸鱼的插件。我也尝试了一些: Zhihu On VSCode,知乎摸鱼。 daily anime,追番插件。 韭菜盒子,看股票、基金、期货实时数据…

GitHub README-Template.md - README.md 模板

GitHub README-Template.md - README.md 模板 1. README-Template.md 预览模式2. README-Template.md 编辑模式References A template to make good README.md. https://gist.github.com/PurpleBooth/109311bb0361f32d87a2 1. README-Template.md 预览模式 2. README-Templat…

Java 面向对象案例 02 (黑马)

代码: public class foodTest {public static void main(String[] args) {//1、构建一个数组food[] arr new food[3];//2、创建三个商品对象food f1 new food("apple","123",3.2,500);food f2 new food("pear","456",4…

大模型学习之书生·浦语大模型6——基于OpenCompass大模型评测

基于OpenCompass大模型评测 关于评测的三个问题Why/What/How Why What 有许多任务评测,包括垂直领域 How 包含客观评测和主观评测,其中主观评测分人工和模型来评估。 提示词工程 主流评测框架 OpenCompass 能力框架 模型层能力层方法层工具层 支持丰富…

JVM系列-3.类的生命周期

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring原理、JUC原理、Kafka原理、分布式技术原理、数据库技术、JVM原理🔥如果感觉博主的文…

计算机网络学习The next day

在计算机网络first day中,我们了解了计算机网络这个科目要学习什么,因特网的概述,三种信息交换方式等,在今天,我们就来一起学习一下计算机网络的定义和分类,以及计算机网络中常见的几个性能指标。 废话不多…

面试经典 150 题 - 多数元素

多数元素 给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的,并且给定的数组总是存在多数元素。 示例 1: 输入:nums [3,2,3] 输出&#xff1…

archlinux 如何解决安装以后没有声音的问题

今天安装完archlinux以后发现看视频没声音 检查一下是否有 /lib/firmware/intel/sof 发现没有 如果你也是这样的话,可以尝试安装: sudo pacman -S sof-firmware 重启后再看看有没有声音: reboot 反正我有声音了

中间件存储设计 - 数组与链表

文章目录 数组ArrayListLinkedListHashMap小结 中间件主要包括如下三方面的基础:数据结构、JUC 和 Netty,接下来,我们先讲数据结构。 数据结构主要解决的是数据的存储方式问题,是程序设计的基座。 按照重要性和复杂程度&#xf…