李沐动手学习深度学习——4.2练习

1. 在所有其他参数保持不变的情况下,更改超参数num_hiddens的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。

通过改变隐藏层的数量,导致就是函数拟合复杂度下降,隐藏层过多可能导致过拟合,而过少导致欠拟合。
我们将层数改为128可得:
在这里插入图片描述

2. 尝试添加更多的隐藏层,并查看它对结果有何影响。

过拟合,导致测试机精确度下降。

3. 改变学习速率会如何影响结果?保持模型架构和其他超参数(包括轮数)不变,学习率设置为多少会带来最好的结果?

过高的学习率导致,梯度跨度过大,使得降低不到对应的驻点。
过低的学习率导致训练缓慢,需要增加epoch。
在训练轮数不变的情况下,我们可以通过for 设置不同的学习率找出最合适的学习率。一般来说设置为0.01或者0.1足以

4. 通过对所有超参数(学习率、轮数、隐藏层数、每层的隐藏单元数)进行联合优化,可以得到的最佳结果是什么?

跑了一次学习率lr=0.01的情况:
在这里插入图片描述

需要大量的训练,但是目前我训练结果是学习率lr=0.1、轮数是num_epochs=10,隐藏层数为1,隐藏层数单元num_hiddens=128。

5. 描述为什么涉及多个超参数更具挑战性。

因为组合的情况更多,当层数越多时,训练时间也更多,这玩意就是炼丹了,看你自己的GPU还有时间、运气。

6. 如果想要构建多个超参数的搜索方法,请想出一个聪明的策略。

套用for 循环暴力破解,时间上肯定慢的要死,我们可以先固定其他变量,挑选一个变量寻找最优解,以此类推对所有的超参数这样使用,但是这种做法肯定不是最优的,只是能够较好的找出比较好的超参数。

由于学校穷逼所以没有闲置GPU服务器,所有的模型只能在colab上进行运行,其中遇到了d2l的版本对应问题,所以对于d2l.train_ch3跑不起来,只能使用自写进行替代如下:

import torch.nn
from d2l import torch as d2l
from IPython import displayclass Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * n       # 创建一个长度为 n 的列表,初始化所有元素为0.0。def add(self, *args):           # 累加self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):                # 重置累加器的状态,将所有元素重置为0.0self.data = [0.0] * len(self.data)def __getitem__(self, idx):     # 获取所有数据return self.data[idx]def accuracy(y_hat, y):"""计算正确的数量:param y_hat::param y::return:"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)            # 在每行中找到最大值的索引,以确定每个样本的预测类别cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy(net, data_iter):"""计算指定数据集的精度:param net::param data_iter::return:"""if isinstance(net, torch.nn.Module):net.eval()                  # 通常会关闭一些在训练时启用的行为metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Animator:"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量的绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):"""向图表中添加多个数据点:param x::param y::return:"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def train_epoch_ch3(net, train_iter, loss, updater):"""训练模型一轮:param net:是要训练的神经网络模型:param train_iter:是训练数据的数据迭代器,用于遍历训练数据集:param loss:是用于计算损失的损失函数:param updater:是用于更新模型参数的优化器:return:"""if isinstance(net, torch.nn.Module):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。net.train()# 训练损失总和, 训练准确总和, 样本数metric = Accumulator(3)for X, y in train_iter:  # 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。# 使用pytorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()  # 方法用于计算损失的平均值updater.step()else:# 使用定制(自定义)的优化器和损失函数l.sum().backward()updater(X.shape())metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):"""训练模型():param net::param train_iter::param test_iter::param loss::param num_epochs::param updater::return:"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):trans_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, trans_metrics + (test_acc,))train_loss, train_acc = trans_metricsprint(trans_metrics)def predict_ch3(net, test_iter, n=6):"""进行预测:param net::param test_iter::param n::return:"""global X, yfor X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + "\n" + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/505766.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java项目:31 基于SSM的勤工俭学管理系统

作者主页:源码空间codegym 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 勤工助学系统有管理员,部门管理员,用户三个角色。 管理员功能有个人中心。管理员管理,部门管理员管理&…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:Popup控制)

给组件绑定popup弹窗,并设置弹窗内容,交互逻辑和显示状态。 说明: 从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 popup弹窗的显示状态在onStateChange事件回调中反馈,其显…

初阶数据结构:二叉树

目录 1. 树的相关概念1.1 简述:树1.2 树的概念补充 2. 二叉树2.1 二叉树的概念2.2 二叉树的性质2.3 二叉树的存储结构与堆2.3.1 存储结构2.3.2 堆的概念2.3.3 堆的实现2.3.3.1 堆的向上调整法2.3.3.2 堆的向下调整算法2.3.3.3 堆的实现 1. 树的相关概念 1.1 简述&a…

MyBatis 学习(六)之动态 SQL

目录 1 动态 SQL 介绍 2 if 标签 3 where 标签 4 set 标签 5 trim 标签 6 choose、when、otherwise 标签 7 foreach 标签 8 bind 标签 1 动态 SQL 介绍 动态 SQL 是 MyBatis 强大特性之一,极大的简化我们拼装 SQL 的操作。MyBatis 的动态 SQL 是基于 OGNL 的…

网络编程(IP、端口、协议、UDP、TCP)【详解】

目录 1.什么是网络编程? 2.基本的通信架构 3.网络通信三要素 4.UDP通信-快速入门 5.UDP通信-多发多收 6.TCP通信-快速入门 7.TCP通信-多发多收 8.TCP通信-同时接收多个客户端 9.TCP通信-综合案例 1.什么是网络编程? 网络编程是可以让设…

【Vue3】PostCss 适配

px 固定的单位,不会进行自适应。rem r root font-size16px 1rem16px,但是需要手动进行单位的换算vw vh 相对于视口的尺寸,不同于百分比(相对于父元素的尺寸)375屏幕 1vw 3.75px 利用插件进行 px(设计稿&…

SpringCloudAlibaba介绍

Spring Cloud Alibaba Spring Cloud Alibaba 是什么?微服务全景图核心特色 大家好,我叫阿明。下面我会为大家准备Spring Cloud Alibaba系列知识体系,结合实战输出案列,让大家一眼就能明白得技术原理,应用于各公司得各…

Github项目推荐-LightMirrors

项目地址 https://github.com/NoCLin/LightMirrors 项目简述 “LightMirrors是一个开源的缓存镜像站服务,用于加速软件包下载和镜像拉取。目前支持DockerHub、PyPI、PyTorch、NPM等镜像缓存服务。 当前项目仍处于早期阶段。”–来自项目说明。 也就是说&#xff…

如何解决线程安全问题(synchronized、原子性、产生线程不安全的原因,锁的特性,加锁的方式等等干货)

文章目录 💐线程不安全的示例💐锁的特性💐产生线程不安全的原因:💐加锁的三种方式 💐线程不安全的示例 对于线程安全问题,这里用一个例子进行讲解👇: 我现在定义一个变…

CTFHUB 命令执行

命令执行 原理: 在编写程序的时候,当碰到要执行系统命令来获取一些信息时,就要调用外部命令的函数,比如php中的exec()、system()等,如果这些函数的参数是由用户所提供的,那么恶意用户就可能通过构造命令拼…

MogaNet实战:使用MogaNet实现图像分类任务(一)

文章目录 摘要安装包安装timm 数据增强Cutout和MixupEMA项目结构计算mean和std生成数据集 摘要 论文:https://arxiv.org/pdf/2211.03295.pdf 作者多阶博弈论交互这一全新视角探索了现代卷积神经网络的表示能力。这种交互反映了不同尺度上下文中变量间的相互作用效…

微信小程序云开发教程——墨刀原型工具入门(编辑页面)

引言 作为一个小白,小北要怎么在短时间内快速学会微信小程序原型设计? “时间紧,任务重”,这意味着学习时必须把握微信小程序原型设计中的重点、难点,而非面面俱到。 要在短时间内理解、掌握一个工具的使用&#xf…