5-pytorch-torch.nn.Sequential()快速搭建神经网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • torch.nn.Sequential()快速搭建网络法
    • 1 生成数据
    • 2 快速搭建网络
    • 3 训练、输出结果
  • 总结


前言

本文内容还是基于4-pytorch前馈网络简单(分类)问题搭建这篇的相同例子,只是为了介绍另一种更加快速搭建网络的方法,看个人喜好用哪一种。
【注】:建议先看完上面链接的博客4,在来看本篇。
这里的这种搭建方法是使用**torch.nn.Sequential()**快速搭建,不用我们在继承重写net类了。

torch.nn.Sequential()快速搭建网络法

1 生成数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as npn_data = torch.ones(100,2)
x0 = torch.normal(2*n_data,1)
y0 = torch.zeros(100,1)
x1 = torch.normal(-2*n_data,1)
y1 = torch.ones(100,1)x = torch.cat((x0,x1),0)
# 在分类问题中标签必须用一维tensor,回归中则没有这个要求
y = torch.cat((y0,y1),0).reshape(-1)
# 在分类问题中标签还需要用torch.LongTensor类型
# 将张量 y 的类型转换为 long,这是因为在 PyTorch 中,分类问题的标签通常是整数类型(long),以便与模型输出的类别概率进行比较,从而计算损失。
y = y.long()fig = plt.figure()
plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=y.data.numpy())
# 给画出来的每一个点标上标签,有点难看,注了吧
# 循环遍历每个数据点,根据其对应的标签添加标签文本
for i in range(len(x)):plt.text(x[i][0], x[i][1], str(int(y[i].item())), fontsize=8)
plt.show()

输出:
在这里插入图片描述

2 快速搭建网络

## 搭建网络method1
# class Net(torch.nn.Module):
#     def __init__(self,n_features,n_hidden,n_output):
#         # 继承原来结构体的全部init属性及方法
#         super(Net,self).__init__()
#         # 线性层就是全连接层
#         self.hidden = torch.nn.Linear(n_features,n_hidden)
#         self.predict = torch.nn.Linear(n_hidden,n_output)
#         
#     def forward(self,x):
#         # 重写继承类的向前传播方法,就是在这个里面选择激活函数的
#         x = F.relu(self.hidden(x))
#         # 分类中输出层也可以不用激活函数,我们最后在对输出结果进行softmax处理
#         x = self.predict(x)
#         return x
#         
# net = Net(2,10,2)
# # 输出层定义2个输出,对输出在进行softmax处理,取出概率最大的元素的下标就是我们分类的类别;与回归有所不同
# # 有点类似机器学习里面的独热编码
# print(net)## 快速搭建法,和前面注释掉的效果是一样的。
net = torch.nn.Sequential(torch.nn.Linear(2,10),torch.nn.ReLU(), # 这里激活函数大写了要torch.nn.Linear(10,2)
)
print(net)

输出:
在这里插入图片描述

3 训练、输出结果

optimizer = torch.optim.SGD(net.parameters(),lr=0.02)
# 分类用交叉熵损失函数
loss_func = torch.nn.CrossEntropyLoss()# 开启matplotlib的交换模式
plt.ion()
for t in range(100):# 这一步其实是调用了类里面的 __call__魔术方法,又学到一个魔术方法out = net(x)loss = loss_func(out,y)# 梯度清零optimizer.zero_grad()# 误差反向传播,求梯度loss.backward()# 进行优化器优化optimizer.step()if t%5 == 0:plt.cla()prediction = torch.max(F.softmax(out,1),1)[1]pred_y = prediction.data.numpy().reshape(-1)target_y = y.data.numpy().reshape(-1)plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=pred_y)accuracy = sum(pred_y==target_y)/200plt.text(1.2,-4,'accuracy=%.2f' % accuracy, fontdict={'size':20,'color':'red'})plt.pause(0.1)
# 关闭matplotlib的交换模式
plt.ioff()
plt.show()

输出:
在这里插入图片描述

# 输出out经softmax处理过后才变成概率
out2probability = F.softmax(out,1)
#print(out2probability.round(decimals=2))
# 取出概率向量里面概率最大的下标就是最终的分类结果
prediction = torch.max(F.softmax(out,1),1)[1]
print(prediction)在这里插入代码片

输出:
在这里插入图片描述

总结

选择那种方法搭建,看个人喜好,效果完全一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/625499.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

南京大学田兴军团队在土壤微生物的功能类群以及在中国土壤生态系统的分解潜力取得进展(两篇)

研究成果一:土壤微生物群落功能是由占据不同生态位的、高度多样化的微生物之间发生的复杂相互作用决定的。这些微生物相互作用对环境变化的反应敏感,并且往往在土壤“暗箱”中悄然发生,故极大地限制了我们对微生物功能的理解。本研究基于微生…

【DM8】ODBC

官网下载ODBC https://www.unixodbc.org/ 上传到linux系统中 /mnt下 [rootstudy ~]#cd /mnt [rootstudy mnt]# tar -zxvf unixODBC-2.3.12.tar.gz [rootstudy mnt]# cd unixODBC-2.3.12/ [rootstudy unixODBC-2.3.12]# ./configure 注意:若是报以上错 则是gcc未安…

文件服务: txt文件预览乱码问题

文章目录 一、背景二、解决方案1、转换流(解决代码与文件编码不一致读取乱码的问题)2、获取文本文件的字符编码 一、背景 在springboot项目中使用springmvc web.resources的形式进行文件访问。本地上传txt文件编码格式为GB2312(中文简体),浏…

【蓝桥杯嵌入式】串口通信与RTC时钟

【蓝桥杯嵌入式】串口通信与RTC时钟 串口通信cubemx配置串口通信程序设计 RTC时钟cubemx配置程序设计 串口通信 cubemx配置 打开串口通信,并配置波特率为9600 打开串口中断 重定义串口接收与发送引脚,默认是PC4,PC5,需要改为P…

计算股价波动率python

上述图片上传gemini,提问:转换为python代码 好的,以下是您发送的图像中公式的 Python 代码: python def stock_volatility(prices, opening_prices, N): """ 计算股票价格的波动率。 参数: p…

Three.js加载glb / gltf模型,Vue加载glb / gltf模型(如何在vue中使用three.js,vue使用threejs加载glb模型)

简介:Three.js 是一个用于在 Web 上创建和显示 3D 图形的 JavaScript 库。它提供了丰富的功能和灵活的 API,使开发者可以轻松地在网页中创建各种 3D 场景、模型和动画效果。可以用来展示产品模型、建立交互式场景、游戏开发、数据可视化、教育和培训等等…

AI决策与专家决策,您更喜欢哪种决策方式?

HI,我是AI智能小助手CoCo。 CoCode开发云智能助手CoCo “大家好,我是CoCode开发云的AI智能小助手CoCo,现在为大家播放关于CoCode开发云AI大家庭的最新消息: 欢迎AI家庭新成员:AI自动决策”。 AI自动决策发布 CoCode开…

【数据结构1-基本概念和术语】

这里写自定义目录标题 0.数据,数据元素,数据项,数据对项,数据结构,逻辑结构,存储结构1.结构1.1逻辑结构1.2存储结构1.2.1 顺序结构1.2.2链式结构 1.3数据结构1.3.1基本数据类型1.3.2抽象数据类型1.3.2.1一个…

谷粒商城part2——环境篇

这里是过来人的学习建议: 1、如有条件电脑内存至少16G起步,条件进一步加个屏幕,条件更进一步租一台至少4G内存的X86架构云服务器,所有部署的东西全扔云服务器上 2、P16,P17没法搭起来的建议照着rerenfast的github上的教…

超高效空气过滤器(ULPA)在半导体制造领域需求旺盛 滤芯为其重要组成部分

超高效空气过滤器(ULPA)在半导体制造领域需求旺盛 滤芯为其重要组成部分 超高效空气过滤器(ULPA)又称超低穿透率空气过滤器,指含有超高效过滤网,对0.1微米粒子捕集效率在99.999%以上的空气过滤器。与高效空…

NineData正式将SQL开发正式升级为数据库DevOps

NineData SQL 开发早期主要提供 SQL 窗口(IDE)功能,产品经过将近两年时间的打磨,新增了大量的企业级功能,时至今日已经服务了上万开发者,覆盖了数据库设计、开发、测试、变更等生命周期的功能。 为了让企业…

uni-app中页面生命周期与vue生命周期的执行顺序对比

应用生命周期 uni-app 支持如下应用生命周期函数: 函数名说明平台兼容onLaunch当uni-app 初始化完成时触发(全局只触发一次),参数为应用启动参数,同 uni.getLaunchOptionsSync 的返回值onShow当 uni-app 启动&#x…