深入学习pytorch笔记

两个重要的函数

  • dir(): 一个内置函数,用于列出对象的所有属性和方法
    在这里插入图片描述

  • help():一个内置函数,用于获取关于Python对象、模块、函数、类等的详细信息
    在这里插入图片描述

Dateset类

  • Dataset:pytorch中的一个类,开发者在训练和测试时,用一个子类去继承Dataset类,继承和重写Dataset类中方法和属性,以加载数据集。
class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])
  • def getitem(self, index):必须重写,用于以加载数据集。
  • def len(self):可不重写,用于计算数据集中样本个数。
    在这里插入图片描述

TensorBoard

  • TensorBoard 是pytorch中一组用于数据可视化的工具,包含在TensorFlow库。
  • SummaryWriter类:用于在给定目录中创建事件文件,在训练时,将数据添加到文件中,用于显示。使用SummaryWriter类创建对象时,若没有给出事件文件名,则默认的事件文件名为run。

损失函数

  • torch.nn.loss():PyTorch 中的一个类,用于计算L1 损失函数,即计算了预测值与实际值之间的L1范数(即绝对差值)。
  • 在创建torch.nn.L1Loss(reduction)对象时,可以传入一个可选的参数reduction,它决定了如何从每个样本的损失中聚合得到最终的损失。
    1. reduction=‘mean’:计算所有样本损失的平均值作为最终损失。默认情况下,reduction参数的值为’mean’,即计算所有样本损失的平均值作为最终损失。
    2. reduction=‘none’:不进行任何聚合操作,直接返回每个样本的损失。
    3. reduction=‘sum’:计算所有样本损失的总和作为最终损失。
    4. reduction= ‘mean_none’: 计算所有样本损失的平均值,但是不除以样本数,即不进行归一化。
    5. reduction=‘sum_none’:计算所有样本损失的总和,但是不乘以样本数,即不进行归一化。
  • 在调用torch.nn.L1Loss()对象时,要传入预测值和实际值。
    在这里插入图片描述
  • torch.nn.MSELoss():PyTorch库中的一个类,用于计算均方误差。MSE损失函数的计算方式是:对于每个样本,计算预测值与真实值之间的平方差,然后取这些平方差的平均值。具体公式为:loss = 1/n Σ (y_pred - y_true)^2,其中n是样本数量。
    在这里插入图片描述
  • torch.nn.CrossEntropyLoss:是PyTorch库中的一个类,用于计算交叉熵损失。
  • 在创建对象时,torch.nn.CrossEntropyLoss()参数:
    1. weight: 类别权重。这是一个一维的tensor,用于为每个类别指定不同的权重。默认值是None,这时所有的类别权重都相等。如果指定了类别权重,那么在计算损失时,每个类别的损失将会根据其对应的权重进行加权平均。
    2. reduction: 损失的归约方式。这个参数决定了如何将交叉熵损失的值从样本级别降低到批次级别。可能的值有:‘none’(不进行归约,返回每个样本的交叉熵损失),‘mean’(对所有样本的交叉熵损失取平均),‘sum’(将所有样本的交叉熵损失相加)。默认值是’mean’。
    3. ignore_index: 被忽略的类别索引。如果设置了该参数,那么在计算交叉熵损失时,该类别对应的损失将被忽略。这个参数主要用于处理数据集中的无效类别或不需要分类的类别。默认值是-100。
  • 在调用torch.nn.CrossEntropyLoss的对象时,需要传入两个参数:
    1. input:这是一个一维或二维张量,表示模型的输出。对于每个输入样本,输出应该是一个长度为类别数量的向量,每个元素表示该类别与输入样本的相似度。
    2. target:这是一个一维张量,表示每个输入样本的正确类别标签。
      在这里插入图片描述

优化器(参数更新)

  • torch.optim.SGD:PyTorch 中的一个类,它实现了随机梯度下降(Stochastic Gradient Descent)算法。
  • 创建类对象时,torch.optim.SGD(params,lr,momentum,dampening,weight_decay,nesterov)的参数:
    1. params:要优化的参数,通常是模型中的参数。
    2. lr:学习率。控制参数更新的步长。默认值是0.01。
    3. momentum:动量。这个参数会考虑之前梯度的方向,使得优化器具有一定的"惯性",有助于加速训练。默认值是0。
    4. dampening:阻尼。这个参数可以防止动量过大导致震荡。默认值是0。
    5. weight_decay:权重衰减。可以防止过拟合,通过对参数本身进行惩罚来控制模型的复杂度。默认值是0,表示不进行权重衰减。
    6. nesterov:是否使用 Nesterov 动量。如果为 True,会使用 Nesterov 动量,否则使用标准 momentum。默认值是False
  • 创建优化器后,我们可以通过调用 optimizer.zero_grad() 清除之前的梯度,然后通过反向传播计算新的梯度,最后使用 optimizer.step() 更新模型的参数。

import torch
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten
from torch.nn import Linear
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)class MY_Dodule(nn.Module):def __init__(self):super(MY_Dodule,self).__init__()self.model = Sequential(Conv2d(3, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,input):output = self.model(input)return outputmy_module = MY_Dodule()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(my_module.parameters(),lr=0.1)
for epoch in range(20):running_loss = 0.0for data in dataloader:images,targets = datainput = imagesoutput = my_module(input)  # 前向转播result_loss = loss(output,targets)  # 计算损失optim.zero_grad()  # 清除之前的梯度result_loss.backward() # 反向转播optim.step() #梯度更新running_loss += result_losspassprint(running_loss)pass

网络模型的使用和修改

  • torchvision.models.vgg16(pretrained,progress):PyTorch 中的一个类,是用来加载预训练的 VGG-16 模型的函数。

    1. pretrained:布尔型,决定是否从 PyTorch 的预训练模型库中加载训练好的权重。如果设为 True,则返回的模型会包含在大规模图像分类任务上训练得到的权重。如果设为 False,则模型不包含预训练的权重,你需要自己训练模型。默认为False。
    2. progress:布尔型,决定是否显示下载预训练模型过程的进度条。如果设为 True,则在下载预训练模型时会显示进度条。默认为True。
  • 在 VGG-16 模型中添加层:model是torchvision.models.vgg16()示例化对象,model.classifier.add_module(str,nn.Module)这个函数接受两个参数。

    1. 模块名称(str):这是你想要添加的模块的名称。你可以自己定义一个有意义的名称,以便在后续的代码中引用这个模块。
    2. 模块对象(nn.Module):这是你想要添加的模块本身。这个模块可以是任何PyTorch定义的神经网络层或者你自己定义的层。
  • 在 VGG-16 模型中修改层:model是torchvision.models.vgg16()示例化对象,model.classifier[n] = nn.Module

    1. n:VGG-16 模型中修改层的层号
    2. nn.Module:修改后的模块本身。这个模块可以是任何PyTorch定义的神经网络层或者你自己定义的层。
      在这里插入图片描述

网络模型的保存与读取

  • torch.save(model, ‘model.pth’):PyTorch 中的一个函数,模型model的权重和参数,保存在指定文件model.pth中。
  • model = torch.load(‘model.pth’):PyTorch 中的一个函数,根据model.pth文件,加载保存的模型并返回给变量 model
  • torch.save(model.state_dict(), ‘model.pth’): 将模型model参数(权重和偏置等,不包括模型的结构),以字典的形式保存到指定的文件 ‘model.pth’ 中。
  • model.load_state_dict(torch.load(‘model.pth’)):torch.load()函数读取文件中模型的参数信息,加载到model模型中。请注意,这种方式要求你在加载模型时已经知道模型model的结构。

模型训练流程(以CIFAR10为例)

  • 第一步:准备数据集,包括训练集和测试集
import torchvision# 准备训练集
train_data = torchvision.datasets.CIFAR10("dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)# 准备测试集
test_data = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
  • 第二步:计算数据长度
# 计算数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))
  • 第三步:用dataloader()加载数据集,将数据集划分为批量子集
# dataloader()加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • 第四步:搭建神经网络,一般用一个单独python文件保存
import torch
from torch import nnclass My_Module(nn.Module):def __init__(self):super(My_Module,self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32 ,32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10),)def forward(self,input):output = self.model(input)return outputif __name__ == '__main__':my_module = My_Module()input = torch.ones((64, 3, 32, 32))output = my_module(input)print(output.shape)
  • 第五步:创建网络模型
# 创建网络模型
my_module = My_Module()
  • 第六步:定义损失函数
loss_f = nn.CrossEntropyLoss()
  • 第七步:定义优化器,进行梯度下降
# 定义优化器,进行梯度下降
learning_rate = 0.01  # 学习效率
optimizer = torch.optim.SGD(my_module, lr=learning_rate)
  • 第八步:设置训练网络模型的一些参数
# 设置训练网络模型的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 10 # 训练的轮次
writer = SummaryWriter("P27")  # 添加tensorboard
  • 第九步:训练网络模型
# 训练网络模型
for i in range(epoch):print("------第{}轮训练开始------".format(i + 1))for data in train_dataloader:images ,targets = datainput = imagesoutput = my_module(input)  # 前向传播loss = loss_f(output, targets)  # 计算损失loss.backward()  # 反向转播optimizer.zero_grad()  #optimizer.step() # 梯度下降total_train_step = total_train_step + 1print("训练次数:{},loss:{}".format(total_train_step, loss.item()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/214044.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot实现大学生就业服务平台系统项目【项目源码】计算机毕业设计

基于springboot实现大学生就业服务平台系统演示 Java技术 Java是由SUN公司推出,该公司于2010年被oracle公司收购。Java本是印度尼西亚的一个叫做爪洼岛的英文名称,也因此得来java是一杯正冒着热气咖啡的标识。Java语言在移动互联网的大背景下具备了显著…

产品经理面试必看!To B和To C产品的隐秘差异,你了解多少?

大家好,我是小米,一位对技术充满热情的产品经理。最近在和小伙伴们交流中发现一个热门话题:To B(面向企业)和To C(面向消费者)的产品经理究竟有何异同?这可是我们产品经理面试中的经…

【SpringCloud】微服务的扩展性及其与 SOA 的区别

一、微服务的扩展性 由上一篇文章(没看过的可点击传送阅读)可知, 微服务具有极强的可扩展性,这些扩展性包含以下几个方面: 性能可扩展:性能无法完全实现线性扩展,但要尽量使用具有并发性和异步…

视频剪辑新招:批量随机分割,分享精彩瞬间

随着社交媒体的普及,短视频已经成为分享生活、交流信息的重要方式。为制作出吸引的短视频,许多创作者都投入了大量的时间和精力进行剪辑。然而,对于一些没有剪辑经验的新手来说,这个过程可能会非常繁琐。现在一起来看云炫AI智剪批…

React16中打印事件对象取不到值的现象及其原因分析

React16中打印事件对象取不到值的现象及其原因分析 一、背景 在最近的开发过程中&#xff0c;遇到了一个看起来匪夷所思的问题❓&#xff1a; <Inputplaceholder"请输入"onChange{(e) > {console.log(e:, e)}}onKeyDown{handleKeyDown} />此时按理来说我…

外汇天眼:多名投资者账户被恶意清空,远离volofinance!

最近&#xff0c;外汇平台volofinance因有多名投资者投诉&#xff0c;“荣幸”成为外汇天眼黑平台榜单中的一员&#xff0c;那么volofinance到底做了什么导致投资者前来投诉曝光呢&#xff1f; 起底volofinace 在网络搜索中&#xff0c;关于volofinance的信息少之又少&#xf…

绽放独特魅力,点亮美好生活

2023年10月至11月&#xff0c;由益田社区党委主办、深圳市罗湖区懿米阳光公益发展中心承办&#xff0c;深圳市温馨社工服务中心协办的“2023年益田社区益田佳人--女性成长课堂”项目顺利完成&#xff0c;此项目分为四个主题&#xff0c;分别是瑜伽、健身操、收纳、花艺技能&…

2023-11-24 事业-代号s-行业数据研报网站-记录

摘要&#xff1a; 2023-11-24 事业-代号s-行业数据研报网站-记录 行业数据研报网站 1、萝卜投研&#xff1a;https://robo.datayes.com 看数据、下载研报、上市公司PE/PB研究等。2、镝数聚&#xff1a;www.dydata.io 全行业数据&报告查找下载平台&#xff0c;覆盖100行业报…

基于51单片机超声波测距测液位及报警设计

**单片机设计介绍&#xff0c; 基于51单片机超声波测距测液位及报警设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于51单片机的超声波测距测液位及报警系统是一种用于测量储液罐或水箱中液位高度并进行液位监测和报警的设…

Element UI的Tabs 标签页位置导航栏去除线条

在实际开发中&#xff0c;我们调整了相关样式&#xff0c;导致导航栏的相关样式跟随不上&#xff0c;如下图所示&#xff1a; 因为我跳转了前边文字的样式并以在导航栏添加了相关头像&#xff0c;导致右边的线条定位出现问题&#xff0c;我在想&#xff0c;要不我继续调整右边…

linux基础5:linux进程1(冯诺依曼体系结构+os管理+进程状态1)

冯诺依曼体系结构os管理 一.冯诺依曼体系结构&#xff1a;1.简单介绍&#xff08;准备一&#xff09;2.场景&#xff1a;1.程序的运行&#xff1a;2.登录qq发送消息&#xff1a; 3.为什么需要内存&#xff1a;1.简单的引入&#xff1a;2.计算机存储体系&#xff1a;3.内存的意义…

Sublime Text 4168最新代码编辑

Sublime Text是一款功能强大的文本编辑器&#xff0c;具有以下主要功能&#xff1a; 支持多种编程语言的语法高亮和代码自动完成功能&#xff0c;包括Python、JavaScript、HTML、CSS等。提供代码片段&#xff08;Snippet&#xff09;功能&#xff0c;可以将常用的代码片段保存…