[oneAPI] 手写数字识别-LSTM

[oneAPI] 手写数字识别-LSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):# Set initial hidden and cell states h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/70597.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在本地搭建WAMP服务器并通过端口实现局域网访问(无需公网IP)

文章目录 前言1.Wamp服务器搭建1.1 Wamp下载和安装1.2 Wamp网页测试 2. Cpolar内网穿透的安装和注册2.1 本地网页发布2.2 Cpolar云端设置2.3 Cpolar本地设置 3. 公网访问测试4. 结语 前言 软件技术的发展日新月异,各种能方便我们生活、工作和娱乐的新软件层出不穷&a…

第57步 深度学习图像识别:CNN可视化(Pytorch)

基于WIN10的64位系统演示 一、写在前面 由于不少模型使用的是Pytorch,因此这一期补上基于Pytorch实现CNN可视化的教程和代码,以SqueezeNet模型为例。 二、CNN可视化实战 继续使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中&…

考研算法第46天: 字符串转换整数 【字符串,模拟】

题目前置知识 c中的string判空 string Count; Count.empty(); //正确 Count ! null; //错误c中最大最小宏 #include <limits.h>INT_MAX INT_MIN 字符串使用发运算将字符加到字符串末尾 string Count; string str "liuda"; Count str[i]; 题目概况 AC代码…

【数据分析入门】Numpy进阶

目录 一、数据重塑1.1 透视1.2 透视表1.3 堆栈/反堆栈1.3 融合 二、迭代三、高级索引3.1 基础选择3.2 通过isin选择3.3 通过Where选择3.4 通过Query选择3.5 设置/取消索引3.6 重置索引3.6.1 前向填充3.6.2 后向填充 3.7 多重索引 四、重复数据五、数据分组5.1 聚合5.2 转换 六、…

腾讯云GPU服务器GN7实例NVIDIA T4 GPU卡

腾讯云GPU服务器GN7实例搭载1颗 NVIDIA T4 GPU&#xff0c;8核32G配置&#xff0c;系统盘为100G 高性能云硬盘&#xff0c;自带5M公网带宽&#xff0c;系统镜像可选Linux和Windows&#xff0c;地域可选广州/上海/北京/新加坡/南京/重庆/成都/首尔/中国香港/德国/东京/曼谷/硅谷…

机器学习---对数几率回归

1. 逻辑回归 逻辑回归&#xff08;Logistic Regression&#xff09;的模型是一个非线性模型&#xff0c; sigmoid函数&#xff0c;又称逻辑回归函数。但是它本质上又是一个线性回归模型&#xff0c;因为除去sigmoid映射函 数关系&#xff0c;其他的步骤&#xff0c;算法都是…

JVM——HotSpot的算法细节实现

一、根节点枚举 固定可作为GC Roots的节点主要在全局性的引用&#xff08;如常量或类静态属性&#xff09;与执行上下文&#xff08;如栈帧中的本地变量表&#xff09;中&#xff0c;尽管目标明确&#xff0c;但查找要做到高效很难。现在java应用越来越庞大&#xff0c;光方法区…

Mr. Cappuccino的第60杯咖啡——Spring之BeanFactory和ApplicationContext

Spring之BeanFactory和ApplicationContext 类图BeanFactory概述功能项目结构项目代码运行结果总结 ApplicationContext概述功能MessageSource&#xff08;国际化的支持&#xff09;概述项目结构项目代码运行结果 ResourcePatternResolver&#xff08;匹配资源路径&#xff09;概…

从源代码编译构建Hive3.1.3

从源代码编译构建Hive3.1.3 编译说明编译Hive3.1.3更改Maven配置下载源码修改项目pom.xml修改hive源码修改说明修改standalone-metastore模块修改ql模块修改spark-client模块修改druid-handler模块修改llap-server模块修改llap-tez模块修改llap-common模块 编译打包异常集合异常…

ViewUI表格Table嵌套From表单-动态校验数据合法性的解决方法

项目场景&#xff1a; 项目需求&#xff1a;在表格中实现动态加减数据&#xff0c;并且每行表格内的输入框&#xff0c;都要动态校验数据&#xff0c;校验不通过&#xff0c;不让提交数据&#xff0c;并且由于表格内部空间较小&#xff0c;我仅保留红边框提示&#xff0c;文字…

【Java转Go】快速上手学习笔记(二)之基础篇一

目录 创建项目数据类型变量常量类型转换计数器键盘交互流程控制代码运算符 创建项目 上篇我们安装好了Go环境&#xff0c;和用IDEA安装Go插件来开发Go项目&#xff1a;【Java转Go】快速上手学习笔记&#xff08;一&#xff09;之环境安装篇 。 这篇我们开始正式学习Go语言。我…

工作流自动化:提升效率、节约成本的重要工具

在现代社会中&#xff0c;软件和技术的运用使得我们的日常活动变得更加简单和高效。然而&#xff0c;这些技术也有自身的特点和独特之处。尽管我们使用这些工具来简化工作&#xff0c;但有时仍需要一些人工干预&#xff0c;比如手动数据录入。在工作场所中&#xff0c;手动数据…