前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/73760.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源数据库Mysql_DBA运维实战 (备份与还原)

Mysql数据库的备份与还原🍃 备份对于数据库而言是至关重要的。当数据文件发生损坏、MySQL服务出现错误、系统内核崩溃、计算机硬件损坏或者数据被误删等事件时,使用一种有效的数据备份方案,就可以快速解决以上所有的问题。MySQL提供了多种备…

unity 之Transform组件(汇总)

文章目录 理论指导结合例子 理论指导 当在Unity中处理3D场景中的游戏对象时,Transform 组件是至关重要的组件之一。它管理了游戏对象的位置、旋转和缩放,并提供了许多方法来操纵和操作这些属性。以下是关于Transform 组件的详细介绍: 位置&a…

Flink之Task解析

Flink之Task解析 对Flink的Task进行解析前,我们首先要清楚几个角色TaskManager、Slot、Task、Subtask、TaskChain分别是什么 角色注释TaskManager在Flink中TaskManager就是一个管理task的进程,每个节点只有一个TaskManagerSlotSlot就是TaskManager中的槽位,一个TaskManager中可…

蓝光眼镜有效吗?科研团队:无法证明防蓝光镜片可以减少视力伤害

8 月 19 日消息,本次由墨尔本大学、莫纳什大学和伦敦城市大学联合进行的科研团队,对来自 6个国家和地区的 17 项已发表的研究进行了深入研究。他们的研究发现,无法证明防蓝光镜片能够减少眼睛的视力伤害或改善佩戴者的睡眠质量等功效。 这项研…

springboot+Vue--打基础升级--(二)写个主菜单导航界面

1. 华为OD机考题 答案 2023华为OD统一考试(AB卷)题库清单-带答案(持续更新) 2023年华为OD真题机考题库大全-带答案(持续更新) 2. 面试题 一手真实java面试题:2023年各大公司java面试真题汇总--…

【李沐】3.2线性回归从0开始实现

%matplotlib inline import random import torch from d2l import torch as d2l1、生成数据集: 看最后的效果,用正态分布弄了一些噪音 上面这个具体实现可以看书,又想了想还是上代码把: 按照上面生成噪声,其中最后那…

YOLOv8改进后效果

数据集 自建铁路障碍数据集-包含路障,人等少数标签。其中百分之八十作为训练集,百分之二十作为测试集 第一次部署 版本:YOLOv5 训练50epoch后精度可达0.94 mAP可达0.95.此时未包含任何改进操作 第二次部署 版本:YOLOv8改进版本 首…

Spring Boot 知识集锦之Spring-Batch批处理组件详解

文章目录 0.前言1.参考文档2.基础介绍2.1. 核心组件 3.步骤3.1. 引入依赖3.2. 配置文件3.3. 核心源码 4.示例项目5.总结 0.前言 背景: 一直零散的使用着Spring Boot 的各种组件和特性,从未系统性的学习和总结,本次借着这个机会搞一波。共同学…

从零实战SLAM-第四课(相机成像及常用视觉传感器)

在七月算法报的班,老师讲的蛮好。好记性不如烂笔头,关键内容还是记录一下吧,课程入口,感兴趣的同学可以学习一下。 --------------------------------------------------------------------------------------------------------…

如何学习专业的学术用语01

问题的提出——凭啥人家写的词汇这么专业 做法一 做法二:做一个专业数据库 专门做教育技术类的

Java进阶篇--迭代器模式

目录 同步迭代器(Synchronous Iterator): Iterator 接口 常用方法: 注意: 扩展小知识: 异步迭代器(Asynchronous Iterator): 常用的方法 注意: 总结&#xff1a…

【制作npm包4】api-extractor 学习

制作npm包目录 本文是系列文章, 作者一个橙子pro,本系列文章大纲如下。转载或者商业修改必须注明文章出处 一、申请npm账号、个人包和组织包区别 二、了解 package.json 相关配置 三、 了解 tsconfig.json 相关配置 四、 api-extractor 学习 五、npm包…

【云计算原理及实战】初识云计算

该学习笔记取自《云计算原理及实战》一书,关于具体描述可以查阅原本书籍。 云计算被视为“革命性的计算模型”,因为它通过互联网自由流通使超级计算能力成为可能。 2006年8月,在圣何塞举办的SES(捜索引擎战略)大会上&a…

万宾燃气管网监测解决方案,守护城市生命线安全

方案背景 城市燃气管网作为连接天然气长输管线与天然气用户的桥梁,担负着向企业和居民用户直接供气的重要职责。随着城市燃气需求的急剧增加,城市燃气管网规模日趋庞大,安全隐患和风险也随之增加。目前,我国燃气管网的运行仍存在…

OLED透明屏采购指南:如何选择高质量产品?

着科技的不断进步,OLED透明屏作为一种创新的显示技术,在各个行业中得到了广泛应用。 在进行OLED透明屏采购时,选择高质量的产品至关重要。在这篇文章中,尼伽将为您提供一个全面的OLED透明屏采购指南,帮助您了解关键步…

Beats:使用 Filebeat 将 golang 应用程序记录到 Elasticsearch - 8.x

毫无疑问,日志记录是任何应用程序最重要的方面之一。 当事情出错时(而且确实会出错),我们需要知道发生了什么。 为了实现这一目标,我们可以设置 Filebeat 从我们的 golang 应用程序收集日志,然后将它们发送…

【CSS动画02--卡片旋转3D】

CSS动画02--卡片旋转3D 介绍代码HTMLCSS css动画02--旋转卡片3D 介绍 当鼠标移动到中间的卡片上会有随着中间的Y轴进行360的旋转&#xff0c;以下是几张图片的介绍&#xff0c;上面是鄙人自己录得一个供大家参考的小视频&#x1f92d; 代码 HTML <!DOCTYPE html>…

Android Stodio编译JNI项目,Cmake出错:Detecting C compiler ABI info - failed

在使用Android Stodio编译JNI项目时出现Cmake错误&#xff0c;报错如下&#xff1a; Execution failed for task :app:configureCMakeDebug[arm64-v8a]. > [CXX1429] error when building with cmake using C:\Users\Dell\AndroidStudioProjects\MyApplication2\app\src\ma…

ssm医院门诊挂号系统源码和论文PPT

ssm医院门诊挂号系统源码和论文PPT008 开题报告 任务书 源码 数据库sql 论文 开发环境&#xff1a; 开发工具&#xff1a;idea 数据库mysql5.7(mysql5.7最佳) 数据库链接工具&#xff1a;navcat,小海豚等 开发技术&#xff1a;java ssm tomcat8.5 1.选题的背景和意义 …

UVC摄像头

1 版本历史 1.1 UVC uvc_version UVC 1.0: Sep-4-2003 UVC 1.1: Jun-1-2005 UVC 1.5: August-9-2012, H.264 video codec. Linux 4.5 introduces UVC 1.5, but does not support H264. 1.2 V4L版本历史 Video4Linux取名的灵感来自1992 Video for Windows&#xff08;V4W&#x…