LeNet-5(fashion-mnist)

文章目录

  • 前言
  • LeNet
  • 模型训练

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

def evaluate_accuracy_gpu(net,data_iter,device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net,torch.nn.Module):net.eval()if not device:device=next(iter(net.parameters())).device# 正确预测的数量,预测的总数eva = 0.0y_num = 0.0with torch.no_grad():for X,y in data_iter:if isinstance(X,list):X=[x.to(device) for x in X]else:X=X.to(device)y=y.to(device)eva += accuracy(net(X), y)y_num += y.numel()return eva/y_num

训练过程同样将数据送入GPU计算

def train_epoch_gpu(net, train_iter, loss, updater,device):# 训练损失之和,训练准确数之和,样本数train_loss_sum = 0.0train_acc_sum = 0.0num_samples = 0.0# timer = d2l.torch.Timer()for i, (X, y) in enumerate(train_iter):# timer.start()updater.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()updater.step()with torch.no_grad():train_loss_sum += l * X.shape[0]train_acc_sum += evaluation.accuracy(y_hat, y)num_samples += X.shape[0]# timer.stop()return train_loss_sum/num_samples,train_acc_sum/num_samplesdef train_gpu(net,train_iter,test_iter,num_epochs,lr,device):def init_weights(m):if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:torch.nn.init.xavier_uniform_(m.weight)net.apply(init_weights)net.to(device)print('training on',device)optimizer=torch.optim.SGD(net.parameters(),lr=lr)loss=torch.nn.CrossEntropyLoss()# num_batches=len(train_iter)tr_l=[]tr_a=[]te_a=[]for epoch in range(num_epochs):net.train()train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)train_loss, train_acc = train_metrictrain_loss = train_loss.cpu().detach().numpy()tr_l.append(train_loss)tr_a.append(train_acc)te_a.append(test_accuracy)print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')x = torch.arange(num_epochs)plt.plot((x + 1), tr_l, '-', label='train_loss')plt.plot(x + 1, tr_a, '--', label='train_acc')plt.plot(x + 1, te_a, '-.', label='test_acc')plt.legend()plt.show()print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/338157.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++day3作业

完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到其他界面 如果账号和密码不匹配&#xf…

【普中开发板】基于51单片机的简易密码锁设计( proteus仿真+程序+设计报告+讲解视频)

基于51单片机的简易密码锁设计 1.主要功能:资料下载链接: 实物图:2.仿真3. 程序代码4. 设计报告5. 设计资料内容清单 【普中】基于51单片机的简易密码锁设计 ( proteus仿真程序设计报告讲解视频) 仿真图proteus8.16(有低版本) 程…

C++_vector类

目录 一、vector的模拟实现 1、vector的组成结构 2、vector尾插数据 2.1 析构函数 3、迭代器实现 4、resize 5、删除数据 5.1 迭代器失效 6、指定位置插入数据 6.1 迭代器失效 7、迭代器构造和resize构造 8、深浅拷贝 结语: 前言: vect…

Beauty algorithm(七)瘦脸

瘦脸的实现采用局部平移法。 一、skills 前瞻 局部平移 二、目标区域定位 左脸: 关键点选择3、5点,基点30 rmax:计算两点5-3间的距离, |x-c|:图像任一点到固定基点c的距离 |m-c|:两固定点距离 右脸: 关键点选择

空间转录组与单细胞转录组联合分析——MIA,代码分享(Nature Biotechnology :)

​ 原文:Integrating microarray-based spatial transcriptomics and single-cell RNA-seq reveals tissue architecture in pancreatic ductal adenocarcinomas | Nature Biotechnology 研究者采用 MIA 联合 scRNAseq 和 ST 数据,分析原发性胰腺导管腺癌…

ThreadLocal如何使用详解

ThreadLocal概述: ThreadLocal是Java中的一个线程局部变量工具类,它提供了一种在多线程环境下,每个线程都可以独立访问自己的变量副本的机制。ThreadLocal中存储的数据对于每个线程来说都是独立的,互不干扰。 使用场景&#xff1a…

编译ZLMediaKit(win10+msvc2019_x64)

前言 因工作需要,需要ZLMediaKit,为方便抓包分析,最好在windows系统上测试,但使用自己编译的第三方库一直出问题,无法编译通过。本文档记录下win10上的编译过程,供有需要的小伙伴使用 一、需要安装的软件…

快来预约出行!浙江宁海“网约巴士”走出新思路,跑出市民新满意度

在浙江宁海,一辆辆金旅星辰网约巴士亮眼吸睛,正在成为市民出行的全新方式。 它以灵动、萌态十足的小巴士为载体,依托大数据平台及互联网出行创新理念,将传统公交与动态拼车相结合,从而塑造出更灵活、更便宜、更便捷、更…

单片机原理及应用:中断嵌套

​中断嵌套是指中断系统正在执行一个中断服务时,有另一个优先级更高的中断提出中断请求,这时会暂时终止当前正在执行的级别较低的中断源的服务程序,去处理级别更高的中断源,待处理完毕,再返回到被中断了的中断服务程序…

【EI会议征稿通知】2024年第三届生物医学与智能系统国际学术会议(IC-BIS 2024)

2024年第三届生物医学与智能系统国际学术会议(IC-BIS 2024) 2024 3rd International Conference on Biomedical and Intelligent Systems (IC-BIS 2024) 2024年第三届生物医学与智能系统国际学术会议(IC-BIS 2024) 将于2024年4月…

Java项目:115SSM宿舍管理系统

博主主页:Java旅途 简介:分享计算机知识、学习路线、系统源码及教程 文末获取源码 一、项目介绍 宿舍管理系统基于SpringSpringMVCMybatis开发,系统主要功能如下: 学生管理班级管理宿舍管理卫生管理维修登记访客管理 二、技术框…

test fuzz-07-模糊测试 libfuzzer

拓展阅读 开源 Auto generate mock data for java test.(便于 Java 测试自动生成对象信息) 开源 Junit performance rely on junit5 and jdk8.(java 性能测试框架。性能测试。压测。测试报告生成。) test fuzz-01-模糊测试(Fuzz Testing) test fuzz-…