深度学习_10_softmax_实战

由于网上代码的画图功能是基于jupyter记事本,而我用的是pycham,这导致画图代码不兼容pycharm,所以删去部分代码,以便能更好的在pycharm上运行

完整代码:

import torch
from d2l import torch as d2l"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"创建模型w, b"
num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)"softmax"
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制"输出,即传入图片输出"
def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)"交叉熵损失"
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter):  #@saveif isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标return metric[0] / metric[1]"加法器"
class Accumulator:  #@savedef __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]lr = 0.1"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)if __name__ == '__main__':num_epochs = 10cnt = 1for i in range(num_epochs):X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)print("训练次数: " + str(cnt))cnt += 1print("训练损失: {:.4f}".format(X))print("训练精度: {:.4f}".format(Y))print(".................................")
#        print(W)
#        print(b)

效果:

在这里插入图片描述

训练效果还是和网上一样的,就是缺了画图功能,将就着吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/168514.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rocky Linux 配置邮件发送

Rocky Linux 配置邮件发送 使用自己的有邮箱发送 第一步-开启STMP授权 首先要开启STMP授权码,以QQ邮箱为例 第二步-下载安装包 说明一点不用命令行安装也可以,在命令行中输入会提示你是否安装s-nail,一直y即可 mail下载必须要的安装包 …

数据结构与算法C语言版学习笔记(6)-树、二叉树、赫夫曼树

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、树的定义1.结点的度、树的度2.结点的逻辑关系3.树的深度4.有序树和无序树5.森林 二、树的存储结构(1)双亲表示法(2&…

leetcode(力扣) 207. 课程表1+2(图的构造与遍历,清晰思路,完整模拟)

文章目录 题目描述思路分析完整代码 题目描述 你这个学期必须选修 numCourses 门课程,记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出,其中 prerequisites[i] [ai, bi] ,表示如果要学…

矢量绘图软件Sketch 99 for mac

Sketch是一款为用户提供设计和创建数字界面的矢量编辑工具。它主要用于UI/UX设计师、产品经理和开发人员,帮助他们快速设计和原型各种应用程序和网站。 Sketch具有简洁直观的界面,以及丰富的功能集,使得用户可以轻松地创建、编辑和共享精美的…

openai自定义API操作 API 返回值说明

custom-自定义API操作 openai.custom 公共参数 名称类型必须描述keyString是调用key(获取测试key)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,item_get,item_search_shop等]cacheStrin…

轻量日志管理方案-[EFK]

使用FileBeat进行日志文件的数据收集,并发送到ES进行存储,最后Kibana进行查看展示; 这个应该是最简单,轻量的日志收集方案了。 最总方案为:FileBeatESKibana ; 【Kibana过于强大,感觉可以无限扩展】 文章目…

Apinto 网关进阶教程,使用 API Mock 生成模拟数据

什么是 API Mock ? API Mock 是一种技术,它允许程序员在不依赖后端数据的情况下,模拟 web服务器端 API 的响应。通常使用 API Mock 来测试前端应用程序,而无需等待后端程序构建完成。API Mock 可以模拟任何 HTTP 请求方法&#x…

linux安装nodejs

写在前面 因为工作需要,需要使用到nodejs,所以这里简单记录下学习过程。 1:安装 wget https://nodejs.org/dist/v14.17.4/node-v14.17.4-linux-x64.tar.xz tar xf node-v14.17.4-linux-x64.tar.xz mkdir /usr/local/lib/node // 这一步骤根…

使用电脑时提示msvcp140.dll丢失的5个解决方法

“计算机中msvcp140.dll丢失的5个解决方法”。在我们日常使用电脑的过程中,有时会遇到一些错误提示,其中之一就是“msvcp140.dll丢失”。那么,什么是msvcp140.dll呢?它的作用是什么?丢失它会对电脑产生什么影响呢&…

人工智能基础_机器学习022_使用正则化_曼哈顿距离_欧氏距离_提高模型鲁棒性_过拟合_欠拟合_正则化提高模型泛化能力---人工智能工作笔记0062

然后我们再来看一下,过拟合和欠拟合,现在,实际上欠拟合,出现的情况已经不多了,欠拟合是 在训练集和测试集的准确率不高,学习不到位的情况. 然后现在一般碰到的是过拟合,可以看到第二个就是,完全就把红点蓝点分开了,这种情况是不好的, 因为分开是对训练数据进行分开的,如果来…

如何将NetCore Web程序独立发布部署到Linux服务器

简介 在将 .NET Core 应用程序部署到 Linux 服务器上时,可以采用独立发布的方式,以便在目标服务器上运行应用程序而无需安装 .NET Core 运行时。本文介绍如果将NetCore Web程序独立发布部署到Linux服务器。 1、准备一台服务器 服务器配置:2核2G 系统环境:Alibaba Cloud…

可视化技术专栏100例教程导航帖—学习可视化技术的指南宝典

🎉🎊🎉 你的技术旅程将在这里启航! 🚀🚀 本文专栏:可视化技术专栏100例 可视化技术专栏100例领略各种先进的可视化技术,包括但不限于大屏可视化、图表可视化等等。订阅专栏用户在文章…