交叉熵Loss多分类问题实战(手写数字)

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as Ftrain_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() +b1x = F.relu(x)x = x@w2.t() +b2x = F.relu(x)x = x@w3.t() +b3x = F.relu(x)return xoptimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criterion(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if (batch_idx+1) % 150 == 0:print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(epoch, (batch_idx+1) * len(data), len(train_loader.dataset),100. * (batch_idx+1) / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28*28)logits = forward(data)test_loss += criterion(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader)print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/133544.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

distcc分布式编译

distcc https://gitee.com/bison-fork/distcc.git 下载工具链 mingw,https://www.mingw-w64.org/downloads/#w64devkitperl,https://strawberryperl.com/releases.html免安装zip版本,autoconf等脚本依赖perlautoconf、automake&#xff0c…

使用wireshark解析ipsec esp包

Ipsec esp包就是ipsec通过ike协议协商好后建立的通信隧道使用的加密包,该加密包里面就是用户的数据,比如通过的语音等。 那么如何将抓出来的esp包解析出来看呢? 获取相关的esp的key信息. 打开wireshark -> edit->preferences 找到pr…

线程入门java

1:线程创建方式二 让子类继承Thead类 必须重写Thead类的run方法 写多态的写法 注意优缺点 线程已经Thread 无法继承其他类 package threadTest;public class ThreadTest1 {//目标:掌握线程的创建方式继承thread类public static void main(String[] …

美创科技三重数据安全韧性,杜绝删库跑路

从删库到跑路,教训很多,但类似事件近年来总在重复上演,有运维部为此连夜鏖战恢复,更有企业陷入“至暗时刻”,经济受损、名誉蒙尘。 组织单位应该采取怎样的策略和积极主动的方法,避免酿成严重的后果&#x…

TensorFlow入门(十九、softmax算法处理分类问题)

softmax是什么? Sigmoid、Tanh、ReLU等激活函数,输出值只有两种(0、1,或-1、1或0、x),而实际现实生活中往往需要对某一问题进行多种分类。例如之前识别图片中模糊手写数字的例子,这个时候就需要使用softmax算法。 softmax的算法逻辑 如果判断输入属于某一个类的概率大于属于其…

云原生Kubernetes:Rancher管理k8s集群

目录 一、理论 1.Rancher 2.Rancher 安装及配置 二、实验 1.Rancher 安装及配置 三、问题 1. Rancher 部署监控系统报错 四、总结 一、理论 1.Rancher (1) 概念 Rancher 简介 Rancher 是一个开源的企业级多集群 Kubernetes 管理平台,实现了 Kubernetes …

软件测试学习(二)静态白盒测试、动态白盒测试、配置测试、兼容性测试、外国语言测试

目录 静态白盒测试:检查设计和代码 正式审查 同事审查 走查 检验 编码标准和规范 通用代码审查清单 数据引用错误 数据声明错误 计算错误 比较错误 控制流程错误 子程序参数错误 输入/输出错误 其他检查 动态白盒测试:带上x光眼镜测试 …

主动配电网故障恢复的重构与孤岛划分matlab程序

微❤关注“电气仔推送”获得资料(专享优惠) 参考文档: A New Model for Resilient Distribution Systems by Microgrids Formation; 主动配电网故障恢复的重构与孤岛划分统一模型; 同时考虑孤岛与重构的配电网故障…

【Vue基础-数字大屏】加载动漫效果

一、需求描述 当网页正在加载而处于空白页面状态时,可以在该页面上显示加载动画提示。 二、步骤代码 1、全局下载npm install -g json-server npm install -g json-server 2、在src目录下新建文件夹mock,新建文件data.json存放模拟数据 {"one&…

linux查看文件内容命令more/less/cat/head/tail/grep

1.浏览全部内容more/less 文件: more:可以查看文件第一屏的内容,同时左下角有一个显示内容占全部文件内容的百分比,空格键会显示下一屏的内容,直到文件末尾 [rootmaster data]# more file1less:相较于mor…

1014蓝桥算法双周赛,学习算法技巧,助力蓝桥杯

家人们,我来免费给大家送福利了!!! 【1014蓝桥算法双周赛 】 背景 蓝桥杯全国软件和信息技术专业人才大赛是由工业和信息化部人才交流中心举办的全国性IT学科赛事。参赛高校超过1200余所,累计参赛人数超过40万人。该…

Xcode升级到15.0 解决DT_TOOLCHAIN_DIR问题

根据个人开发遇到的问题做的总结,公司要求Xcode 14.2 ,Swift 5.7开发,由于升级了Mac 14.0系统后,Xcode 14.2不能使用,解决方案目前有2个 一、在原来Xcode 14.2 的显示包内容,如图 二、升级到Xcode的15.0后…