机器学习优化器和SGD和SGDM实验对比(编程实现SGD和SGDM)

机器学习优化器和SGD和SGDM实验对比

博主最近在学习优化器,于是呢,就做了一个SGD和SGDM的实验对比,可谓是不做不知道,一做吓一跳,这两个算法最终对结果的影响还是挺大的,在实验中SGDM明星要比SGD效果好太多了,SGD很容易陷入局部最优,而且非常容易发生梯度爆炸的情况,而SGDM做的实验当中还未出现这些情况。
在这次实验中博主发现了很多很多的特点对于SGDM和SGDM,下面博主列出这次实验的收获。
(1)SGDM相比SGD当拥有同样的学习率,SGDM更不容易发生梯度爆炸,SGD对于学习率的要求很高,大了,就会梯度爆炸,小了迭代特别慢。
(2)在本次此实验中,我们可以发现,小批量梯度下降比单个样本进行梯度下降区别极为大,单个样本做梯度下降时,特别容易发生梯度爆炸,模型不易收敛。
(3)SGDM相比SGD,loss下降曲线更加平稳,也不易陷入局部最优,但是他的训练较慢,可以说是非常慢了。
(4)超参数的设置对于这两个模型的影响都是很大的,要小心处理。
(5)数据集对于模型迭代也有很大影响,注意要对数据集进行适当的处理。
(6)随着训练轮次的增多,SGDM相比SGD更有可能取得更好的效果。

下面让我们看一看代码:

#coding=gbkimport torch
from torch.autograd import Variable
from torch.utils import data
import matplotlib.pyplot as pltX =torch.randn(100,4)
w=torch.tensor([1,2,3,4])Y =torch.matmul(X, w.type(dtype=torch.float))  + torch.normal(0, 0.1, (100, ))+6.5
print(Y)
Y=Y.reshape((-1, 1))#将X,Y转成200 batch大小,1维度的数据def loss_function(w,x,y,choice,b):if choice==1:return torch.abs(torch.sum(w@x)+b-y)else:#    print("fdasf:",torch.sum(w@x),y)#   print(torch.pow(torch.sum(w@x)-y,2))return torch.pow(torch.sum(w@x)-y+b,2)
index=0batch=32
learning_rating=0.03def SGDM(batch,step,beta,grad_s,grad_b_s):if step==0:grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_b#   print(loss_function(w,X[index],Y[index],2,b))loss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  grad/batch,loss/batch,grad_b/batchelse:grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_bloss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  (beta*grad_s+(1-beta)*grad)/batch,loss/batch,(beta*grad_b_s+(1-beta)*grad_b)/batchdef train(n):loss_list=[]setp=0global grad,grad_bgrad=0grad_b=0while n:n=n-1grad,loss,grad_b=SGDM(batch,setp,0.99,grad,grad_b)setp=setp+1# print(grad,loss,grad_b)w.data=w.data+learning_rating*grad*w.datab.data=b.data+learning_rating*grad_b# print("b",b)#print("grad_b",grad_b)#print("w:",w)#print("loss:",loss)#print("b:",b)loss_list.append(float(loss))#  b.data=b.data-(lear#  ning_rating*b.grad.data)#   print("b",b)print("w:",w)print("b:",b)print("loss:",loss)return loss_listdef SGD(batch):grad=Variable(torch.tensor([0.0]),requires_grad=True)grad_b=Variable(torch.tensor([0.0]),requires_grad=True)loss=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index]+b)*(-1)*X[index]+gradgrad_b=(torch.sum(w@X[index])-Y[index]+b)*(-1)+grad_bloss=loss_function(w,X[index],Y[index],2,b)+lossindex=index+1except:index=0return  grad/batch,loss/batch,grad_b/batchdef train_s(n):loss_list=[]while n:if n//100==0:print(n)n=n-1grad,loss,grad_b=SGD(batch)# print(grad,loss,grad_b)w.data=w.data+learning_rating*grad*w.datab.data=b.data+learning_rating*grad_b# print("b",b)#print("w:",w)#print("loss:",loss)#print("b:",b)#  b.data=b.data-(learning_rating*b.grad.data)#   print("b",b)loss_list.append(float(loss))print("w:",w)print("b:",b)print("loss:",loss)return loss_listw=torch.tensor([1,1.0,1,1])b=torch.tensor([1.0])
w=Variable(w,requires_grad=True)b=Variable(b,requires_grad=True)epoch=10000
epoch_list=list(range(1,epoch+1))loss_list=train(epoch)plt.plot(epoch_list,loss_list,label='SGDM')#SGD
w=torch.tensor([1,1.0,1,1])b=torch.tensor([1.0])
w=Variable(w,requires_grad=True)b=Variable(b,requires_grad=True)
print(w)epoch_list=list(range(1,epoch+1))loss_list=train_s(epoch)plt.plot(epoch_list,loss_list,label='SGD')
plt.legend()
plt.show()

下面是一张跑出的实验图,事实上,我做了很多很多的实验,这是一件十分有趣的事情,在实验中,你可以看到这些优化器的特点,这很有趣,当然前提是这个优化器是你自己完全编程写的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/7491.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springBoot配置多环境

在代码中一般有3个环境,为了避免频繁的每次上线需要手动更改环境的问题。 test 本地测试环境,代码调试的 dev 服务端开发环境-用来验证用 prod 服务端正式环境 我创建2个做示例,里面写的不同配置 点运行的项目会有一个Edit Configurations…

​浅谈大型语言模型

大型语言模型(Large Language Models,LLMs)是一类强大的人工智能模型,具有出色的自然语言处理能力。它们在许多任务中表现出色,如机器翻译、文本摘要、对话生成和情感分析等。下面我们将介绍大型语言模型的训练和生成过…

智谱AI-算法实习生(知识图谱方向)实习面试记录

岗位描述 没错和我的经历可以说是match得不能再match了,但是还是挂了hh。 面试内容 给我面试的是唐杰老师的博士生,方向是社交网络数据挖掘,知识图谱。不cue名了,态度很友好的 ,很赞。 date:6.28 Q1 自…

【码银送书第一期】通用人工智能:初心与未来

目录 前言 正文 内容简介 作者简介 译者简介 目录 前言 自20世纪50年代图灵在其划时代论文《计算机器与智能》中提出“图灵测试”以及之后的达特茅斯研讨会开始,用机器来模仿人类学习及其他方面的智能,即实现“人工智能”(Artificial …

ORA-31664: unable to construct unique job name when defaulted

某个环境备份不足空间问题处理后,手动执行expdp备份的脚本,报错如下 Export: Release 11.2.0.4.0 - Production on Tue Jul 4 11:46:14 2023 Copyright (c) 1982, 2011, Oracle and/or its affiliates. All rights reserved. Connected to: Oracle D…

获取移动设备的电池信息

通过BatteryManager来获取关于电池的信息 实例 package com.example.softwarepatentdemo;import android.content.BroadcastReceiver; import android.content.Context; import android.content.Intent; import android.content.IntentFilter; import android.os.BatteryManag…

2023年大学计算机专业实习心得14篇

2023年大学计算机专业实习心得精选篇1 20__年已然向我们挥手告别而去了。在20__年初之际,让我们对过去一年的工作做个总结。忙碌的一年里,在领导及各位同事的帮助下,我顺利的完成了20__年的工作。为了今后更好的工作,总结经验&…

Docker|kubernetes|本地镜像批量推送到Harbor私有仓库的脚本

前言: 可能有测试环境,而测试环境下有N多的镜像,需要批量导入到自己搭建的Harbor私有仓库内,一般涉及到批量的操作,自然还是使用脚本比较方便。 本文将介绍如何把某个服务器的本地镜像 推送到带有安全证书的私有Harb…

GAMES101笔记 Lecture07 Shading1(Illumination, Shading and Graphics Pipeline)

目录 Visibility / Occlusion(可见性 or 遮挡)Painters Algorithm(画家算法)Z-Buffer(深度缓冲算法) Shading(着色)A Simple Shading Model(Blinn-Phong Reflectance Model)一个简单的着色模型:Blinn-Phong反射模型Diffuse Reflection(漫反射) 参考资源 Visibility …

性能测试该怎么做,终于找到方法了

目录 开头 分类 服务器与场景设计 计算TPS 设计场景 场景运用 单交易最大压力: 单交易稳定性: 混合场景稳定性: 业务指标: 数据库 中间件 负载均衡: 最后: 开头 性能测试的工具有很多&#xf…

FreeRTOS学习笔记—基础知识

文章目录 一、什么是RTOS二、前后台系统三、实时内核(可剥夺型内核)四、RTOS系统五、FreeRTOS系统简介六、FreeRTOS源码下载 一、什么是RTOS RTOS全称为:Real Time OS,就是实时操作系统,核心在于实时性。实时操作系统又分为硬实时…

Servlet

1.Servlet是什么 Servlet是一种实现动态页面的技术。是一组Tomcat提供给程序员的API,帮助程序员简单高效的开发一个web app 回顾 动态页面 VS 静态页面 静态页面也就是内容固定的页面,即使 用户不同/时间不同/输入参数不同,页面的内容也不…