[oneAPI] 手写数字识别-BiLSTM

[oneAPI] 手写数字识别-BiLSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(BiRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirectiondef forward(self, x):# Set initial statesh0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)  # 2 for bidirectionc0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/69961.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实习笔记(一)

自定义注解: 自定义注解中有三个元注解Target,Retention,Document /*** 系统日志注解** author Mark sunlightcsgmail.com*/ Target(ElementType.METHOD) Retention(RetentionPolicy.RUNTIME) Documented public interface SysLog {String value() default "…

nodejs+vue+elementui学生档案信息管理系统_06bg9

利用计算机网络的便利,开发一套基于nodejs的大学生信息管理系统,将会给人们的生活带来更多的便利,而且在经济效益上,也会有很大的便利!这可以节省大量的时间和金钱。学生信息管理系统是学校不可缺少的一个环节,其内容直…

Mac安装opencv后无法导入cv2的解决方法

前提条件:以下两个插件安装成功 pip install opencv-python pip install --user opencv-contrib-python 注:直接用pip install opencv-contrib-python如果报错,就加上“–user" 第一步: 设置–添加python解释器 第二步&am…

FOSSASIA Summit 2023 - 开源亚洲行

作者 Ted 致歉:本来这篇博客早就该发出,但是由于前几个月频繁差旅导致精神不佳,再加上后续我又参加了 Linux 基金会 7/27 在瑞士日内瓦举办的 Open Source Congress,以及 7/29-30 台北的 COSCUP23,干脆三篇连发&#x…

VScode替换cmd powershell为git bash 终端,并设置为默认

效果图 步骤 1. 解决VScode缺少git bash的问题_failed to start bash - is git-bash.exe on the syst_Rudon滨海渔村的博客-CSDN博客效果解决步骤找到git安装目录下的/bin/bash.exe,复制其绝对路径,例如D:\Program Files\Git\bin\bash.exe把路径的右斜…

mysql知识点+面试总结

目录 1 mysql介绍 2 数据库常见语法 3 数据库表的常见语法 4 其他常见语法(日期,查询表字段) 5 JDBC开发步骤 6 索引 6.1 索引常见语法 7 常见面试总结 8 java代码搭建监控页面 1 mysql介绍 数据库:存储在硬盘上的文件系统…

docker compose部署zookeeper

单机部署 新建docker-compose.yaml version: 3 services:zookeeper:image: zookeeper:3.5.7container_name: base-zookeeperhostname: zookeeperprivileged: truerestart: alwaysports:- 2181:2181environment:TZ: "Asia/Shanghai"volumes:- ./volumes/zookeeper/d…

无涯教程-Perl - symlink函数

描述 此函数在OLDFILE和NEWFILE之间创建符号链接。在不支持符号链接的系统上,会导致致命错误。 语法 以下是此函数的简单语法- symlink ( OLDFILE, NEWFILE )返回值 如果失败,此函数返回0,如果成功,则返回1。 例 以下是显示其基本用法的示例代码,首先在/tmp目录中创建一…

RS485、MODBUS通信协议详解

前言 MODBUS协议是Modicon公司发表的一种串行通信协议,属于OSI模型中应用层的协议,现广泛应用于工业控制领域,它的主要特点是免费开放、支持多种电气接口(如RS-232、RS-485),传输介质可以是双绞线、光纤、无…

微信小程序:模板使用

目录 模板的优点: 一、静态模板创建 二、静态模板使用 1.*.wxml引入模板 2.模板使用 3.*.wxss引入模板的样式 三、动态模板创建 四、动态模板使用 1.*.wxml引入模板 2.模板使用 3.*.js定义动态数据 五、结果展示 总结 模板的优点: 有利于保持网…

opencv进阶06-基于K邻近算法识别手写数字示例

opencv 中 K 近邻模块的基本使用说明及示例 在 OpenCV 中,不需要自己编写复杂的函数实现 K 近邻算法,直接调用其自带的模块函数即可。本节通过一个简单的例子介绍如何使用 OpenCV 自带的 K 近邻模块。 本例中有两组位于不同位置的用于训练的数据集&…

人工智能学习框架—飞桨Paddle人工智能

1.人工智能框架 机器学习的三要素:模型、学习策略、优化算法。 当我们用机器学习来解决一些模式识别任务时,一般的流程包含以下几个步骤: 1.1.浅层学习和深度学习 浅层学习(Shallow Learning):不涉及特征学习,其特征…