CNN从搭建到部署实战(pytorch+libtorch)

模型搭建

下面的代码搭建了CNN的开山之作LeNet的网络结构。

import torchclass LeNet(torch.nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_sizetorch.nn.Sigmoid(),torch.nn.MaxPool2d(2, 2), # kernel_size, stridetorch.nn.Conv2d(6, 16, 5),torch.nn.Sigmoid(),torch.nn.MaxPool2d(2, 2))self.fc = torch.nn.Sequential(torch.nn.Linear(16*4*4, 120),torch.nn.Sigmoid(),torch.nn.Linear(120, 84),torch.nn.Sigmoid(),torch.nn.Linear(84, 10))def forward(self, img):feature = self.conv(img)flat = feature.view(img.shape[0], -1)output = self.fc(flat)return outputnet = LeNet()   
print(net)
print('parameters:', sum(param.numel() for param in net.parameters()))

运行代码,输出结果:

LeNet((conv): Sequential((0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(1): Sigmoid()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(4): Sigmoid()(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc): Sequential((0): Linear(in_features=256, out_features=120, bias=True)(1): Sigmoid()(2): Linear(in_features=120, out_features=84, bias=True)(3): Sigmoid()(4): Linear(in_features=84, out_features=10, bias=True))
)
parameters: 44426

模型训练

编写训练代码如下:16-18行解析参数;20-21行加载网络;23-24行定义损失函数和优化器;26-36行定义数据集路径和数据变换,加载训练集和测试集(实际上应该是验证集);37-57行for循环中开始训练num_epochs轮次,计算训练集和测试集(验证集)上的精度,并保存权重。

import torch
import torchvision
import time
import argparse
from models.lenet import netdef parse_args():parser = argparse.ArgumentParser('training')parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')return parser.parse_args()if __name__ == '__main__':args = parse_args()batch_size = args.batch_sizenum_epochs = args.num_epochsdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = net.to(device)loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.Adam(net.parameters(), lr=0.001)train_path = r'./Datasets/mnist_png/training'test_path = r'./Datasets/mnist_png/testing'transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]transform = torchvision.transforms.Compose(transform_list)train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)for epoch in range(num_epochs):train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()for X, y in train_iter:X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l += l.cpu().item()train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()m += y.shape[0]batch_count += 1with torch.no_grad():for X, y in test_iter:net.eval() # 评估模式test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式n += y.shape[0]print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))torch.save(net, "checkpoint.pth")

该代码支持cpu和gpu训练,损失函数是CrossEntropyLoss,优化器是Adam,数据集用的是手写数字mnist数据集。训练的部分打印日志如下:

epoch 0, loss 1.486503, train acc 0.506, test acc 0.884, time 25.8s
epoch 1, loss 0.312726, train acc 0.914, test acc 0.938, time 33.3s
epoch 2, loss 0.185561, train acc 0.946, test acc 0.960, time 27.4s
epoch 3, loss 0.135757, train acc 0.960, test acc 0.968, time 24.9s
epoch 4, loss 0.108427, train acc 0.968, test acc 0.972, time 19.0s

模型测试

测试的代码非常简单,流程是加载网络和权重,然后读入数据进行变换再前向推理即可。

import cv2
import torch
from pathlib import Path
from models.lenet import net
import torchvision.transforms.functionalif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = net.to(device)net = torch.load('checkpoint.pth')net.eval()with torch.no_grad():imgs_path = Path(r"./Datasets/mnist_png/testing/0/").glob("*")for img_path in imgs_path:img = cv2.imread(str(img_path), 0)img_tensor = torchvision.transforms.functional.to_tensor(img)img_tensor = torch.unsqueeze(img_tensor, 0)print(net(img_tensor.to(device)).argmax(dim=1).item())

输出部分结果如下:

0
0
0
0
0
0
0
0

模型转换

下面的脚本提供了pytorch模型转换torchscript和onnx的功能。

import torch
from models.lenet import netif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = net.to(device)net = torch.load('checkpoint.pth')x = torch.rand(1, 1, 28, 28)x = x.to(device)traced_script_module = torch.jit.trace(net, x)traced_script_module.save("checkpoint.pt")torch.onnx.export(net,x,"checkpoint.onnx",opset_version = 11)

onnx模型netron可视化结果如下:
在这里插入图片描述

模型部署

10.png图片
在这里插入图片描述

libtorch部署

#include <iostream>
#include <opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h> int main(int argc, char* argv[])
{std::string model = "checkpoint.pt";torch::jit::script::Module module = torch::jit::load(model);module.to(torch::kCUDA);cv::Mat image = cv::imread("10.png", 0);image.convertTo(image, CV_32F, 1.0 / 255);at::Tensor img_tensor = torch::from_blob(image.data, { 1, 1, image.rows, image.cols }, torch::kFloat32).to(torch::kCUDA);torch::Tensor result = module.forward({ img_tensor }).toTensor();std::cout << result << std::endl;std::cout << result.argmax(1) << std::endl;return 0;
}

输出结果:

 8.0872 -6.3622  0.0291 -1.7327 -4.0367  0.8192  0.8159 -3.2559 -1.8254 -2.2787
[ CUDAFloatType{1,10} ]0
[ CUDALongType{1} ]

opencv dnn部署

只用OpenCV也可以部署模型,其中的dnn模块可以解析onnx格式模型。

#include <iostream>
#include <opencv2/opencv.hpp>int main(int argc, char* argv[])
{std::string model = "checkpoint.onnx";cv::dnn::Net net = cv::dnn::readNet(model);cv::Mat image = cv::imread("10.png", 0), blob;cv::dnn::blobFromImage(image, blob, 1. / 255., cv::Size(28, 28), cv::Scalar(), true, false);net.setInput(blob);std::vector<cv::Mat> output;net.forward(output, net.getUnconnectedOutLayersNames());std::vector<float> values;for (size_t i = 0; i < output[0].cols; i++){values.push_back(output[0].at<float>(0, i));}std::cout << std::distance(values.begin(), std::max_element(values.begin(), values.end())) << std::endl;return 0;
}

输出结果:

0

本文的完整工程可见:https://github.com/taifyang/deep-learning-pytorch-demo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/23097.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows搭建git服务器 无法识别 ‘git‘ 命令:exec: “git“: executable file not found in %PATH%

无法识别 git 命令&#xff1a;exec: "git": executable file not found in %PATH% 确保已经安装git&#xff0c;如下图配置环境变量即可。

Mysql查询

Mysql查询 一.DQL基础查询1.语法2.特点3.查询结果处理 二.单行函数(1)字符函数(2)逻辑处理(3)数学函数(4)日期函数 三.分组函数四.条件查询五.比较六.模糊查询七.UNION和UNION ALL(1)UNION(2)UNION ALL 八.排序九.数量限制十.分组查询 一.DQL基础查询 DQL&#xff08;Data Que…

奇葩功能实现:级联选择框组件el-cascader实现同一级的二级只能单选,但是一级可以多选

前言&#xff1a; 其实也不能说这个功能奇葩&#xff0c;做项目碰到这种需求也算合理正常&#xff0c;只是确实没有能直接实现这一需求的现成组件。 el-cascader作为级联选择组件&#xff0c;并不能同时支持一级多选&#xff0c;二级单选的功能&#xff0c;只能要么是单选或者…

SpringBoot 配置文件:什么是配置文件?配置文件是干什么?

文章目录 &#x1f387;前言1.配置文件的格式2. properties配置文件说明2.1 properties基本语法2.2 读取配置文件 3. yml 配置文件说明3.1 yml 基本语法 4.properties与yml 对比 &#x1f387;前言 学习一个东西&#xff0c;我们先要知道它有什么用处。整个项目中所有重要的数…

java单元测试(调试)

文章目录 测试分类JUnit单元测试介绍引入本地JUnit.jar编写和运行Test单元测试方法设置执行JUnit用例时支持控制台输入10.6.6 定义test测试方法模板 测试分类 **黑盒测试&#xff1a;**不需要写代码&#xff0c;给输入值&#xff0c;看程序是否能够输出期望的值。 **白盒测试…

【解决】Android Studio打包出现not found for signing config ‘externalOverride‘

问题出现场景 之前我的这个项目在另一台电脑上开发&#xff0c;现在迁移到这台计算机上&#xff0c;出现了key报错的问题&#xff0c;网络上有些说需要在XML中进行配置signature相关的内容&#xff0c;这个感觉比较复杂&#xff0c;本文主要介绍一个简单的解决方法&#xff0c;…

VectorCAST单元测试参数配置

一、打开 VectorCAST 通常情况下&#xff0c;技术人员会配置一个脚本文件&#xff08;.bat、.cmd&#xff09;&#xff0c;用户可以通过这个脚本文件来启动 VectorCAST。使用脚本文件启动 VectorCAST&#xff0c;可以在启动时设置好编译器相关的环境变量&#xff0c;方便 Vecto…

el-ment ui 表格组件table实现列的动态插入功能

在实际需求中我们经常遇到各种奇葩的需求&#xff0c;不足为奇。每个项目的需求各不相同&#xff0c;实现功能的思路大致是一样的。 本文来具体介绍怎么实现table表格动态插入几列。 首先实现思路有2种&#xff0c; 1. 插入的位置如果是已知的&#xff0c;我知道在哪个标题的…

redis -速成

目录 &#xff08;一&#xff09;认识 Redis 1.1数据库分类 1.2 什么是Redis 1.2.1 redis简介 1.2.2 谁在用Redis 1.2.3 怎么学redis 1.2.4 Redis的安装 2 数据类型 2.1 概况 2.2 String类型 2.2.1 常用的命令 2.2.2 非常用命令 2.2.3 举例 2.2.4应用场景&#xf…

ELk日志平台搭建

ELk日志平台搭建 一、ELK概述 1.ELK简介 ELK平台是一套完整的日志集中处理解决方案&#xff0c;将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用&#xff0c; 完成更强大的用户对日志的查询、排序、统计需求。2.组件 ●ElasticSearch&#xff1a;是基于Lucene…

qt和vue的交互

1、首先在vue项目中引入qwebchannel /******************************************************************************** Copyright (C) 2016 The Qt Company Ltd.** Copyright (C) 2016 Klarlvdalens Datakonsult AB, a KDAB Group company, infokdab.com, author Milian …

LangChain + ChatGLM2-6B 搭建个人专属知识库

之前教过大家利用 langchain ChatGLM-6B 实现个人专属知识库&#xff0c;非常简单易上手。最近&#xff0c;智谱 AI 研发团队又推出了 ChatGLM 系列的新模型 ChatGLM2-6B&#xff0c;是开源中英双语对话模型 ChatGLM-6B 的第二代版本&#xff0c;性能更强悍。 树先生之所以现…