【动手学习深度学习--逐行代码解析合集】10Dropout暂退法

【动手学习深度学习】逐行代码解析合集

10Dropout暂退法


视频链接:动手学习深度学习–Dropout暂退法
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

1、暂退法原理

在这里插入图片描述
在这里插入图片描述

2、从零开始实现暂退法

import torch
from torch import nn
from d2l import torch as d2limport os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 该函数以dropout的概率丢弃张量输入X中的元素
def dropout_layer(X, dropout):assert 0 <= dropout <= 1# 在本情况中,所有元素都被丢弃if dropout == 1:return torch.zeros_like(X)# 在本情况中,所有元素都被保留if dropout == 0:return X# torch.rand(X.shape)生成0-1之间的均匀随机分布,大于dropout的返回1,小于的返回0mask = (torch.rand(X.shape) > dropout).float()# mask随机生成0或1return mask * X / (1.0 - dropout)
# 测试dropout_layer函数,暂退概率分别为0、0.5和1。
X=  torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

运行结果
在这里插入图片描述

2.1 定义模型参数

# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

2.2 定义模型

我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。 下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5, 并且暂退法只在训练期间有效。

# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5
dropout1, dropout2 = 0.2, 0.5class Net(nn.Module):# is_training = True:给程序标注是在训练def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)  # 第一个隐藏层self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)  # 第二个隐藏层self.lin3 = nn.Linear(num_hiddens2, num_outputs)  # 输出层self.relu = nn.ReLU()  # 激活函数def forward(self, X):# 对第一个隐藏层作非线性激活后,再使用dropoutH1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)# 对第二个隐藏层作非线性激活H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)# 输出层不作用dropoutout = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

2.3 训练和测试

# 训练和测试
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

若不使用dropout对比结果(此处将dropout1, dropout2 = 0.0, 0.0)
在这里插入图片描述

3、暂退法的简洁实现

# 简洁实现
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),  # 第一个隐藏层nn.ReLU(),  # Dropout放在ReLU前后均可# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),  # 第二个隐藏层nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))   # 输出层# 初始化权重,此处不懂可看05softmax回归的简洁实现
def init_weights(m):if type(m) == nn.Linear:# m.weight默认为0,以均值为0方差为0.01来随机初始化权重nn.init.normal_(m.weight, std=0.01)
# net.apply(init_weights)会递归地将函数init_weights应用到父模块的每个子模块submodule,也包括model这个父模块自身。
net.apply(init_weights);# 参数更新
trainer = torch.optim.SGD(net.parameters(), lr=lr)
# 训练画图
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/13168.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 创建一个线程的基础开销探讨

测试代码 测试方法比较笨&#xff0c;每修改一次线程数&#xff0c;就重新编译一次&#xff0c;再运行。在程序运行过程中&#xff0c;查看到进程 pid&#xff0c;然后通过以下命令查看进程的运行状态信息输出到以线程数为名字的日志文件中&#xff0c;最后用 vimdiff 对比文件…

chatglm docker镜像,一键部署chatglm本地知识库

好久没有写文章了&#xff0c;今天有空&#xff0c;记录一下chatglm本地知识库的docker镜像制作过程。 核心程序是基于“闻达”开源项目&#xff0c;稍作改动。镜像可以直接启动运行&#xff0c;大家感兴趣可以进入镜像内部查看&#xff0c;代码位于 /app 目录下。 一、制作镜…

多元分类预测 | Matlab全连接神经网络(DNN)分类预测,多特征输入模型

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 全连接神经网络(DNN)分类预测,多特征输入模型 多特征输入单输出的二分类及多分类模型。程序内注释详细,直接替换数据就可以用。程序语言为matlab,程序可出分类效果图,迭代优化图,混淆矩阵图。 部分源码

SpringBoot配置外部Tomcat项目启动流程源码分析

前言 SpringBoot应用默认以Jar包方式并且使用内置Servlet容器(默认Tomcat)&#xff0c;该种方式虽然简单但是默认不支持JSP并且优化容器比较复杂。故而我们可以使用习惯的外置Tomcat方式并将项目打War包。 【1】创建项目并打War包 ① 同样使用Spring Initializer方式创建项目 …

【kafka面试题2】如何保证kafka消息的顺序性

【kafka面试题】如何保证kafka消息的顺序性 一、整体策略 如何保证kafka消息的顺序性呢&#xff0c;其实整体的策略就是&#xff1a;我们让需要有序的消息发送到同一个分区Partition。 为什么说让有序的消息发送到同一个分区Partition就行呢&#xff0c;&#xff0c;下面我们…

Python学习笔记(十六)————异常相关

目录 &#xff08;1&#xff09;异常概念 &#xff08;2&#xff09;异常的捕获 ①异常捕获的原因 ②捕获常规异常 ③捕获指定异常 ④捕获多个异常 ⑤ 捕获异常并输出描述信息 ⑥捕获所有异常 ⑦异常else ⑧异常的finally &#xff08;3&#xff09;异常的传递 &#xff08…

Idea社区版创建SpringBoot

一 下载Spring Initalizr and Assistant插件 选择左上角的File->Settings->Plugins&#xff0c;在搜索框中输入Spring&#xff0c;出现的第一个Spring Boot Helper插件&#xff0c;点击Installed&#xff0c;下载插件。&#xff08;这里已经下载&#xff09; 二 创建Spr…

【MySQL练习及单表查询】

一、MySQL练习 一.创建表&#xff1a; 创建员工表employee&#xff0c;字段如下&#xff1a; id&#xff08;员工编号&#xff09; name&#xff08;员工名字&#xff09; gender&#xff08;员工性别&#xff09; salary&#xff08;员工薪资&#xff09; 二.插入数据 1&…

【Windows】Redis单机部署

下载redis 下载地址&#xff1a;Releases microsoftarchive/redis GitHub 1、下载后解压&#xff0c;在文件根目录下创建两个文件夹dbcache、logs 修改配置文件redis.windows.conf &#xff08;1&#xff09;配置redis地址&#xff1a; bind 127.0.0.1 &#xff08;2&am…

Redis常见数据结构

文章目录 前言一、Redis通用命令二、String类型三、Key的层级结构四、Hash类型五、List类型六、Set类型七、SortedSet类型 前言 Redis是一个key-value的数据库&#xff0c;key一般是String类型&#xff0c;但是value的类型多种多样 在学习Redis不同数据类型时&#xff0c;我们…

AIGC - Stable Diffusion 图像控制插件 ControlNet (OpenPose) 配置与使用

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/131591887 论文&#xff1a;Adding Conditional Control to Text-to-Image Diffusion Models ControlNet 是神经网络结构&#xff0c;用于控制预…

Vision Pro销售策略曝光,面罩/头带/屈光镜片加大零售难度

彭博社Mark Gurman再次发布了关于苹果Vision Pro的销售策略&#xff0c;以及零售方面的难题。 一、销售计划和策略 1&#xff0c;2024年初先在美国部分门店销售&#xff0c;仅线下购买&#xff0c;线上暂不开放。购买方式是先线上预约&#xff08;可能要提供面部扫描图、眼镜…