零基础学习人工智能—Python—Pytorch学习(十二)

news/2024/12/16 11:07:38/文章来源:https://www.cnblogs.com/kiba/p/18609581

前言

本文介绍使用神经网络进行实战。
使用的代码是《零基础学习人工智能—Python—Pytorch学习(九)》里的代码。

代码实现

mudule定义

首先我们自定义一个module,创建一个torch_test17_Model.py文件(这个module要单独用个py文件定义),如下:

import torch.nn as nn
import torch.nn.functional as Fclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xreturn x

module创建

编写创建module的py文件,代码如下:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch_test17_Model as tmdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')input_size = 784 
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 200 # 要训练200-400轮效果最好transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils. data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) model = tm.ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)n_total_steps = len(train_loader)
print("number total epochs(训练的回合):",num_epochs)
print("number total steps(训练的次数):",n_total_steps)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# images.shape: torch.Size([100, 3, 32, 32]) # images张量的四个维度是(B, C, H, W)# B 是批量大小(即图像的数量)。# C 是图像的通道数(例如,RGB 图像的通道数是 3)。# H 和 W 分别是图像的高度和宽度。print("images.shape:", images.shape) #100行,后面的维度是3,32,32。这个是图片信息。# lables是对应images这100个图片的标签print("labels.shape:", labels.shape)print("labels[0].item():", labels[0].item())  # 输出例子 labels[0].item()=6images = images.to(device)labels = labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)print("loss.item()",loss.item())  # 输出例子 loss.item()=2.300053596496582# 逆向传播和优化optimizer.zero_grad()loss.backward() ##执行逆向传播 会使用criterion的函数关系求偏导,然后把x的值,带入偏导公式求值,然后再乘以loss,得到新x值optimizer.step()print(f'训练轮次Epoch [{epoch}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')print('==================')
print('训练结束')filePath = "model.pth" #没有路径,会保存到python文件所在目录
torch.save(model, filePath)
print('保存完成')

代码会输出loss的值,我们要重点关注这个值。
Loss 值越大,表示模型的预测与真实标签之间的差距较大,模型的性能较差。
Loss 值越小,表示模型的预测更接近真实标签,性能逐渐提高。
即,loss值接近0的时候,这个模型就可以用了。

module使用

编写使用module验证图片的py文件,注意要引用torch_test17_Model.py文件,代码如下:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch_test17_Model as tmdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')batch_size = 100transform = transforms.Compose([transforms.Resize((32, 32)),# 如果预测时处理的图片尺寸与训练时不同,如评估输入的图片尺寸为 [100, 3, 64, 64],而模型训练使用的尺寸是 [100, 3, 32, 32],可以用這個转换一下transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)filePath = "model.pth" #没有路径,会保存到python文件所在目录
model = torch.load(filePath,weights_only=False)
model.eval() # 切换到评估模式############################使用阈值判断######################################
threshold = 0.7  # 设定一个阈值,表示模型的信心度,用阈值判断的话,要求模型必须更精确,如果只是两轮的训练,会出现全部判定不过去的情况
with torch.no_grad():for images, labels in test_loader:print("############################判断######################################")images = images.to(device)labels = labels.to(device)outputs = model(images)print("outputs.shape",outputs.shape)# 计算 softmax 概率probabilities = F.softmax(outputs, dim=1)max_probs, predicted = torch.max(probabilities, 1)for i in range(len(predicted)):if max_probs[i] < threshold:  # 如果置信度低于阈值,认为是未知类别print(f"图片 {i} 被认为是未知类别,置信度 {max_probs[i]:.4f}")else:print(f"图片 {i} 被认为是类别 {predicted[i]},置信度 {max_probs[i]:.4f}")

判断图片是什么的时候,使用阈值模式。

结语

到此,我们对于神经网络,卷积神经网络,深度网络都有了一定了解。
然后我们就可以继续学习transformer了。


传送门:
零基础学习人工智能—Python—Pytorch学习—全集


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18609581

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/853682.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

700PB数据的数仓依然“快稳省”!ByteHouse这本白皮书揭秘关键(内附下载链接)

12月10日,《火山引擎ByteHouse云数仓产品白皮书》在线上发布。在数字经济蓬勃发展的今天,企业面临着数据量爆炸性增长、数据分析需求日益复杂的双重挑战。传统的数据仓库解决方案已经难以满足企业对数据处理速度和灵活性的高要求。为了应对这些挑战,火山引擎于2021年正式推出…

子查询关联条件字段没有指定表的别名导致的查询结果不正确的问题

子查询关联查询问题,子查询关联条件字段没有指定表的别名导致的查询结果不正确的问题首先介绍一下表结构和背景;有两个数据库表,供应商XX任务主表和供应商等级变更记录表; 等级表里面有多个任务,两张表是通过同名称的字段,supplier_id关联; ①SQL是XX任务表关联供应商等…

Xinference环境搭建推理测试

引子 写了很多篇开源大模型的环境部署与推理搭建,截止到目前,开源大模型已经发展较为完善。个人觉得,产品和项目维度来看更多的是如果去落地实现,也就是大模型的最后一公里的应用开发。最近看到Xinference一个开源很火的推理框架。OK,那就让我们开始吧。 一、框架介绍 Xin…

前端工程化_CSS 工具链_学习笔记

本文主要介绍了 CSS 工具链,可以看出工具链的出现都是为了解决语言的问题,文中就介绍了预处理器和后处理器,预处理器主要介绍了 sass,并举了星空这个例子,sass 是通过与预编译器编译成 css 后给 html 使用;后处理器则介绍了 postcss,其中 postcss 和 babel 类似,都有很…

车载以太网TSN设计及测试解决方案

智能汽车电子电气架构全面向中央+区域式发展,车载通信新技术是新架构技术栈的重要组成部分。车载以太网时间敏感网络TSN技术凭借其低延时、高可靠的特点获得多家OEM的认可。依赖多年技术研发及数十个项目的实践积累,经纬恒润可为客户提供全面、专业且本土化的TSN设计与测试解…

看板软件:跨境电商圣诞营销加速器

看板软件在跨境电商中发挥着多重作用,特别是在圣诞节这一关键销售时期。通过清晰有序的任务管理、灵活适配的自定义功能、高效的信息整合与数据分析、以及精准有效的营销策略应用,看板软件显著提升了跨境电商团队的协作效率和销售能力。圣诞节作为全球最重要的购物节日之一,…

LameUI:轻量级嵌入式图形用户界面的绝佳选择

在信息技术迅猛发展的今天,嵌入式系统逐渐成为各种智能设备的核心。这些系统往往面临资源有限的挑战,因此在开发用户界面时,使用轻量级、易于实现的库显得尤为重要。在这种背景下,LameUI 应运而生。作为一个轻量级且平台无关的图形用户界面库,LameUI 旨在为开发者提供简便…

分享图片

测试图片分享

LT1121IST-5#TRPBF 规格书 数据手册具有关断功能的微功率低压差稳压器芯片

LT1121/LT1121-3.3/LT1121-5是具有关断功能的微功率低压差稳压器。这些设备能够以0.4V的压降提供150mA的输出电流。这些设备设计用于电池供电系统,低静态电流(30A运行,16A关断)使其成为理想的选择。静态电流得到良好控制,不会像许多其他低压差PNP稳压器那样在压降时上升。…

OPA828IDR OPA2828 数据手册一款低失调电压、低温漂、低噪声输入运算放大器芯片

OPA828 和 OPA2828 (OPAx828) JFET 输入运算放大器是 OPA627 和 OPA827 的下一代产品,兼具高速度、高直流精度和高交流性能。这些运算放大器可提供低失调电压、低温漂、低偏置电流和低噪声,噪声仅为60nVRMS 0.1Hz 至 10Hz。OPAx828 在 4V 至 18V的宽电源电压范围内工作,每通…

vue2 脚手架安装及使用

1.安装npm install -g @vue/cli 2.查看版本vue -V 3.使用3.1 命令形式vue create my-project 3.2可视化操作

.NET8升级.NET9,CodeFirst模式迁移Add-Migration执行Update-DataBase报错

在做netcore开发时,如果net8一直是正常的,只升级了一下框架net9,在使用Entity Framework Core的Code First模式进行迁移时,执行Add-Migration后尝试使用Update-DataBase时出现了如下错误。Unhandled exception. System.InvalidOperationException: An error was generated …