深度学习——优化器Optimizer

代码以及详细注释:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt# torch.manual_seed(1)    # reproducible
"""超参数
"""
# 学习率
LR = 0.01
# 批大小
BATCH_SIZE = 32
# 轮次
EPOCH = 12# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# # plot dataset
# plt.scatter(x.numpy(), y.numpy())
# plt.show()# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
# 数据加载器
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)# default network
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20)   # hidden layerself.predict = torch.nn.Linear(20, 1)   # output layerdef forward(self, x):x = F.relu(self.hidden(x))      # activation function for hidden layerx = self.predict(x)             # linear outputreturn xif __name__ == '__main__':# 相同的网络结构net_SGD         = Net()net_Momentum    = Net()net_RMSprop     = Net()net_Adam        = Net()# 将上面的网络集成到这里nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]# 不同的优化器opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))# 将上面的优化器集成到这里optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]# 损失函数loss_func = torch.nn.MSELoss()losses_his = [[], [], [], []]   # record loss# 训练轮次for epoch in range(EPOCH):print('Epoch: ', epoch)# 分批训练for step, (b_x, b_y) in enumerate(loader):          # for each training stepfor net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x)              # get output for every netloss = loss_func(output, b_y)  # compute loss for every netopt.zero_grad()                # clear gradients for next trainloss.backward()                # backpropagation, compute gradientsopt.step()                     # apply gradientsl_his.append(loss.data)        # loss recoder# 绘图labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, l_his in enumerate(losses_his):plt.plot(l_his, label=labels[i])plt.legend(loc='best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show()

运行结果:

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/22203.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java中abstract关键字

文章目录 由来语法格式使用说明应用举例 由来 举例1: 随着继承层次中一个个新子类的定义,类变得越来越具体,而父类则更一般,更通用。类的设计应该保证父类和子类能够共享特征。有时将一个父类设计得非常抽象,以至于它…

【模式识别目标检测】——基于机器视觉的无人机避障RP-YOLOv3实例

目录 引入 一、YOLOv3模型 1、实时目标检测YOLOv3简介 2、改进的实时目标检测模型 二、数据集建立&结果分析 1、数据集建立 2、模型结果分析 三、无人机避障实现 参考文献: 引入 目前对于障碍物的检测整体分为:激光、红外线、超声波、雷达、…

【超全面】Linux嵌入式干货学习系列教程

文章目录 一、前言二、Linux基础篇三、数据结构与算法基础三、Linux应用篇四、Linux网络篇五、ARM篇六、Linux系统移植篇七、Linux驱动篇八、Linux特别篇九、Linux项目篇 一、前言 博主学习Linux也有几个月了,在这里为广大朋友整理出嵌入式linux的学习知识&#xff…

Matplotlib入门与实践(一)

Matplotlib 是一个 Python 的 2D绘图库,它以各种硬拷贝格式和跨平台的交互式环境生成出版质量级别的图形。通过 Matplotlib,开发者可以仅需要几行代码,便可以生成绘图,直方图,功率谱,条形图,错误…

oceanbase基础

与mysql对比 分布式一致性算法 paxos 存储结构(引擎)用的是两级的 数据库自动分片功能,提供独立的obproxy路由写入查询等操作到对应的分片 多租户 方便扩展 存储层 http://www.hzhcontrols.com/new-1391864.html LSM tree,is very…

渲染流程(上):HTML、CSS和JavaScript,是如何变成页面的?

在上一篇文章中我们介绍了导航相关的流程,那导航被提交后又会怎么样呢? 就进入了渲染阶段。这个阶段很重要,了解其相关流程能让你“看透”页面是如何工作的,有了这些知识,你可以解决一系列相关的问题,比如…

干货分享|SOLIDWORKS Composer如何解决缺失的actor?

​SOLIDWORKS Composer导入SOLIDWORKS模型,以便用户可以创建图形内容并与更广泛的受众共享项目。但是,有时模型导入时缺少Actor或组件,通常是由于在SOLIDWORKS中以轻量模式加载组件或Composer中的导入设置排除了曲面实体。 轻量模式 轻量模式…

学习C#基础知识和应用:

C#语言基础知识:了解C#的开发环境、变量、语法和程序结构等基础内容。这些知识是理解和开发C#自动化控制系统的前提。刚好,我这里有上位机入门,学习线路图,各种项目,需要留个6。 Winform窗体控件的应用:Wi…

WTM框架页面被其他网站引用免登录

用ASP.NET CORE开发通常都会有这样一个需求,自己框架开发的页面,要被其他网站嵌套引用,但其他网站通过链接到自己的开发页面的时候,通常会有一个登录页面,有的时候网站无缝集成的时候,这就会要求跳过这个WT…

Flutter:自定义错误显示

为什么要自定义错误处理 以下面数组越界的错误为例&#xff1a; class _YcHomeBodyState extends State<YcHomeBody> {List<String> list [苹果, 香蕉];overrideWidget build(BuildContext context) {return Center(child: Column(children: [Text(list[0]),Tex…

库迪身陷“价格”囹圄,融资苦难户还有突围的希望吗?

作者 | 心怡 来源 | 洞见新研社 三伏天已至&#xff0c;正是咖啡品牌借冰咖笑傲市场的好时机。没想到的是&#xff0c;靠低价狂奔的库迪却率先传出涨价的消息。 消息称&#xff0c;7月起&#xff0c;库迪划线价格上调1-2元&#xff0c;8.8元的团购价涨到9.9元&#xff0c;热门…

open3d 通过vscode+ssh连接远程服务器将可视化界面本地显示

当使用远程服务器时&#xff0c;我们希望能像在本地一样写完代码后能立刻出现一些gui窗口。但是目前网络上的资料都不能很好的解决这个问题。本文尝试尽可能简短地解决这个问题。 步骤 1、在服务器上安装open3d 已经非常简化了&#xff0c;可以使用一行代码完成 pip3 insta…