pytorch深度学习分类代码简单示例

train.py代码如下

import torch
import torch.nn as nn
import torch.optim as optimmodel_save_path = "my_model.pth"# 定义简单的线性神经网络模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.output = nn.Linear(2, 4)   # 输入2个特征,输出4个类别def forward(self, x):x = self.output(x)return xdef main():# 数据点x = torch.tensor([[0, 0], [0, 10], [10, 0], [10, 10]], dtype=torch.float32)y = torch.tensor([0, 1, 2, 3], dtype=torch.long)# 初始化模型model = MyModel()# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型num_iterations = 10000  # 迭代次数for i in range(num_iterations):model.train()# 前向传播:计算预测输出y_pred = model(x)# 计算损失loss = criterion(y_pred, y)# 输出每1000次迭代的损失值if i % 1000 == 0:print(f"迭代 {i},损失:{loss.item():.4f}")# 反向传播与梯度更新optimizer.zero_grad()  # 清除梯度loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 打印优化后的权重和偏置print("\n优化后的权重和偏置:")for name, param in model.named_parameters():if param.requires_grad:print(f"{name} = {param.data.numpy()}")# 保存模型
    torch.save(model.state_dict(), model_save_path)print(f"模型已保存到 {model_save_path}")if __name__ == "__main__":main()
View Code

运行结果

test.py代码如下

import numpy as np
import torch
from torch import nnfrom train import MyModel, model_save_path# 加载模型
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load(model_save_path))
loaded_model.eval()  # 切换到评估模式# 定义预测数据
input_data = [0, 9]# 使用加载的模型进行预测
x_new = torch.tensor([input_data], dtype=torch.float32)  # 新数据点
y_new_pred = loaded_model(x_new)  # 计算预测值# 使用softmax计算每个类别的概率
softmax = nn.Softmax(dim=1)
y_new_pred_probs = softmax(y_new_pred)# 找到预测的类别
predicted_class = torch.argmax(y_new_pred_probs, dim=1)# 将概率分布四舍五入到三位小数
y_new_pred_probs_rounded = np.round(y_new_pred_probs.detach().numpy(), 3)print(f"\n对x = {input_data}的预测类别:{predicted_class.item()}")
print(f"预测类别的概率分布:{y_new_pred_probs_rounded}")# 打印权重和偏置
weights = loaded_model.output.weight  # 获取输出层权重
bias = loaded_model.output.bias  # 获取输出层偏置print(f"\n模型权重:\n{weights}")
print(f"\n模型偏置:\n{bias}")# 计算input_data * 模型权重 + 模型偏置
with torch.no_grad():linear_output = x_new @ weights.t() + biasprint(f"\ninput_data * weights + bias ={linear_output.numpy()}")# 手动计算Softmax概率分布
linear_output_np = linear_output.numpy()
exp_output = np.exp(linear_output_np)
softmax_output_manual = exp_output / np.sum(exp_output)print(f"\n手动计算的Softmax概率分布:{softmax_output_manual}")
print(f"手动计算的预测类别:{np.argmax(softmax_output_manual)}")
View Code

运行结果

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/779154.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

轻松上手Markdown进阶:揭秘那些让你事半功倍的小秘诀!

讲讲其他关于 Markdown 的奇技淫巧110.其他Markdown技巧 讲讲其他关于 Markdown 的杂技。 ‍ ‍ ‍ Slidev 官网:cn.sli.dev/guide Slidev 是一款专门为开发者打造的演示文稿工具,目前在 Github 上已有 23K+Star​。 通过 Slidev,我们只要使用熟悉的 Markdown 就可以做出炫酷…

Yolo便捷GUI工具

1. 一键搭建服务2. 代码调用

深度解读昇腾CANN小shape算子计算优化技术,进一步减少调度开销

摘要:Host调度模式下,GE将模型中算子的执行单元划分为Host CPU执行与Device(昇腾AI处理器)执行两大类。 本文分享自华为云社区《深度解读昇腾CANN小shape算子计算优化技术,进一步减少调度开销》,作者:昇腾CANN。 GE(Graph Engine)将模型的调度分为Host调度与下沉调度两…

正运动控制

一、IP设置1.以太网IP设置:要和板卡IP在同一个IP段 2.注意:不能和板卡IP相同,不然会冲突 3.查询板卡IP是否存在,通过cmd输入:ping ip注意:电脑和板卡连接不上,可能是板卡和电脑不在同一个IP段,或者没有扫描找到运控板卡IP电脑设置 控制面板 >> 更改适配器设置以太…

windows操作系统通过nvm安装pm2,并解决不是内部或外部命令的解决方案

在Windows环境中安装nvm(Node Version Manager,Node版本管理器)的步骤如下: 一、下载nvm访问nvm的GitHub发布页面:前往nvm-windows的GitHub发布页面下载最新版本的nvm安装包。https://github.com/coreybutler/nvm-windows/releases下载nvm安装包:在发布页面中找到适合您系…

MySQL UDF 提权初探

MySQL UDF 提权初探 对 MySQL UDF 提权做一次探究,什么情况下可以提权,提取的主机权限是否跟mysqld进程启动的主机账号有关 数据库信息 MySQL数据库版本:5.7.21 UDF UDF:(User Defined Function) 用户自定义函数,MySQL数据库的初衷是用于方便用户进行自定义函数,方便查询一…

特殊字符,十六进制 0xa0导致的搜索问题

导致后端在处理的时候出现一些错误本文来自博客园,作者:chuangzhou,转载请注明原文链接:https://www.cnblogs.com/czzz/p/18346469

电路基础知识——常见晶振电路

电路基础知识——常见晶振电路 本文介绍了有源和无源晶振的特性,包括精度、稳定性、引脚配置以及晶振的选型参数,如工作电压、输出电平、频率精度等。此外,还讨论了晶振的类型,如SPXO、VCXO和TCXO,以及PCB设计中应注意的事项,如负载电容和热传导的影响。 有源晶振 有源晶…

后端开发学习敏捷需求--专题的目标与价值成效

专题的目标与价值成效 什么是专题公司或企业为了抓住业务机会或者解决痛点问题,而采取的具体的行动和举措专题的目标分析 1.业务调研了解目标的预期 利用5W2H来进行专题分析what——是什么?目的是什么?作什么工作?专题是什么 专题产生的背景是什么 专题的目标是什么,要达到…

大数据超全面入门干货知识,看这一篇就够了!

随着科技的飞速发展和互联网的普及,大数据已成为 21 世纪最炙手可热的话题之一。它像一面神秘的面纱,覆盖着现实世界,隐藏着无穷无尽的可能性。今天将带领大家一起揭开大数据这个未知世界的神秘面纱,带你了解大数据的概念、应用以及大数据相关组件。 一、什么是大数据大数据…

USB基础知识总结

USB基础知识总结 USB基本概念介绍 USB (Universal Serial Bus,通用串行总线)是1995年英特尔和微软等公司联合倡导发起的一种新的** PC 串行通信协议。它基于通用连接技术,实现外设的简单快速连接,达到方便用户、降低成本、扩展 PC 连接外设范围的目的。其最大特点是支持热插…