神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

  • 学习前言
  • 为什么每次训练结果不同
  • 什么是随机种子
  • 训练中设置随机种子

学习前言

好多同学每次训练结果不同,最大的指标可能会差到3-4%这样,这是因为随机种子没有设定导致的,我们一起看看怎么设定吧。
在这里插入图片描述

为什么每次训练结果不同

模型训练中存在很多随机值,最常见的有:
1、随机权重,网络有些部分的权重没有预训练,它的值则是随机初始化的,每次随机初始化不同会导致结果不同。
2、随机数据增强,一般来讲网络训练会进行数据增强,特别是少量数据的情况下,数据增强一般会随机变化光照、对比度、扭曲等,也会导致结果不同。
3、随机数据读取,喂入训练数据的顺序也会影响结果。
……
应该还有别的随机值,这里不一一列出,这些随机都很容易影响网络的训练结果。

如果能够固定权重、固定数据增强情况、固定数据读取顺序,网络理论上每一次独立训练的结果都是一样的。

什么是随机种子

随机种子(Random Seed)是计算机专业术语。一般计算机的随机数都是伪随机数,以一个真随机数(种子)作为初始条件,然后用一定的算法不停迭代产生随机数。

按照这个理解,我们如果可以设置最初的 真随机数(种子),那么后面出现的随机数将会是固定序列。

以random库为例,我们使用如下的代码,前两次为随机生成,后两次为设置随机数生成器种子后生成。

import random# 生成随机整数
print("第一次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))# 生成随机整数
print("第二次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))# 设置随机数生成器种子
random.seed(11)# 生成随机整数
print("第一次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))# 重置随机数生成器种子
random.seed(11)# 生成随机整数
print("第二次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

结果如下,前两次随机生成的序列不同,后两次设定种子后随机生成的序列相同:

第一次随机生成
66
37
第二次随机生成
93
56
第一次设定种子后随机生成
58
72
第二次设定种子后随机生成
58
72

训练中设置随机种子

一般训练会用到多个库包含有关random的内容。

在pytorch构建的网络中,一般都是使用下面三个库来获得随机数,我们需要对三个库都设置随机种子:
1、torch库;
2、numpy库;
3、random库。

在这里写了一个函数:

#---------------------------------------------------#
#   设置种子
#---------------------------------------------------#
def seed_everything(seed=11):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False

这里面写到了cuda、cudnn这类gpu才会用到的东西,实测发现cpu版本的pytorch也可以正常运行。

torch.backends.cudnn.deterministic=True用于保证CUDA 卷积运算的结果确定。
torch.backends.cudnn.benchmark=False是用于保证数据变化的情况下,减少网络效率的变化。为True的话容易降低网络效率。

只需要在所有初始化前,调用该seed初始化函数即可。

另外,Pytorch一般使用Dataloader来加载数据,Dataloader一般会使用多worker加载多进程来加载数据,此时我们需要使用Dataloader自带的worker_init_fn函数初始化Dataloader启动的多进程,这样才能保证多进程数据加载时数据的确定性。

#---------------------------------------------------#
#   设置Dataloader的种子
#---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):worker_seed = rank + seedrandom.seed(worker_seed)np.random.seed(worker_seed)torch.manual_seed(worker_seed)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/15341.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elasticsearch 8.X 聚合查询下的精度问题及其解决方案

1、线上环境问题 咕泡同学提问:我在看runtime文档的时候做个测试, agg求avg的时候不管是double还是long,数据都不准确,这种在生产环境中如何解决啊? 2、问题归类及出现场景 上述问题可以归类为:Elasticsear…

OpenCV 入门教程:Laplacian算子和Canny边缘检测

OpenCV 入门教程: Laplacian 算子和 Canny 边缘检测 导语一、Laplacian 算子二、Canny 边缘检测三、示例应用3.1 图像边缘检测3.2 边缘增强 总结 导语 边缘检测在图像处理和计算机视觉领域中起着重要的作用。 Laplacian 算子和 Canny 边缘检测是两种常用的边缘检测…

在tplink路由器xdr6088中运行Docker潘多拉(pengzhile/pandora)遇到无法访问的问题

在xdr6088中搜索安装pengzhile/pandora一切正常,但是按照常规运行docker后,直接访问8899端口无法打开页面,进入终端 运行如下命令 /usr/local/bin/python /usr/local/bin/pandora-cloud -s 0.0.0.0:8899 即可成功运行,然后客户端…

面试交流会

面试交流 目录: 01:关于人生目标 02:关于心态 03:关于选择 04:关于学习 05:关于简历 06:关于找工作 1:关于人生目标 1.01、自己想成为什么样的人? 1.02、自己的人生目标是…

如何用https协议支持小程序

步骤一:下载SSL证书 登录数字证书管理服务控制台。在左侧导航栏,单击SSL 证书。在SSL证书页面,定位到目标证书,在操作列,单击下载。 在服务器类型为Nginx的操作列,单击下载。 解压缩已下载的SSL证书压缩…

kubernetes环境搭建及部署

一、kubernetes 概述 1、kubernetes 基本介绍 kubernetes,简称 K8s,是用 8 代替 8 个字符“ubernete”而成的缩写。是一个开源 的,用于管理云平台中多个主机上的容器化的应用,Kubernetes 的目标是让部署容器化的 应用简单并且高效…

flutter tabBar 的属性及自定义实现

flutter tabBar 的属性及自定义实现 前言一、TabBar是什么?二、TabBar 自定义三、 Tab 自定义总结 前言 在Flutter中,TabBar的indicatorPadding属性用于设置指示器的内边距,而不是用于调整指示器和文字之间的间距。要调整TabBar中指示器和文字…

C#学习之路-常量

C# 常量 常量是固定值,程序执行期间不会改变。常量可以是任何基本数据类型,比如整数常量、浮点常量、字符常量或者字符串常量,还有枚举常量。 常量可以被当作常规的变量,只是它们的值在定义后不能被修改。 整数常量 整数常量可…

【密码学篇】GM密码行业标准下载方法

【密码学篇】GM密码行业标准下载方法 截止到2023年07月08日,密码行业标准化技术委员会共发布了144个密码行业标准,可点击链接预览或使用IDM下载器下载标准,此外该方法很多场景都适用,自行尝试—【蘇小沐】 文章目录 【密码学篇】…

【Linux】文件描述符 (上篇)

文章目录 📖 前言1. 文件的预备知识2. 复习C语言的文件操作3. Linux系统级文件接口3.1 open、 close、 read、 write 接口:3.2 内核当中实现的映射关系:3.3 如何理解Linux下一切皆文件: 📖 前言 本章开始,…

LVS负载均衡集群 keepalived

目录 1.实现方法 1.故障自动切换 (failover) 2.节点健康状态检查 (health checking) 2.实现LVS负载调度器 节点服务器的高可用(HA) 3.keepalived高可用故障切换原理 4.三个主要模块 5.案例 1.实现方法 1.故障自动切换 (failover) 主…

【Linux | Shell】Linux 安全系统 —— 用户、组、文件权限 - 阅读笔记

目录 一、Linux 的安全性1.1 /etc/passwd 文件1.2 /etc/shadow 文件1.3 添加新用户 —— useradd1.4 删除用户 —— userdel1.5 修改用户 —— usermod、passwd、chpasswd 二、使用 Linux 组2.1 /etc/group 文件2.2 创建新组 —— groupadd2.3 修改组 —— groupmod 三、理解文…