视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务

检测出已经分割出的图像的分类

2 使用pytorch

pytorch 非常简单就可以做到训练和加载

2.1 准备数据

在这里插入图片描述
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行

里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
在这里插入图片描述
train.txt 如下图所示
在这里插入图片描述
val.txt 也是同样如此

3 show me the code

3.1 装载数据类

新增一个loaddata.py 文件

import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):def __init__(self, root, datatxt, transform=None, target_transform=None):super(LoadData, self).__init__()file_txt = open(datatxt,'r')imgs = []for line in file_txt:line = line.rstrip()words = line.split('|')imgs.append((words[0], words[1]))self.imgs = imgsself.root = rootself.transform = transformself.target_transform = target_transformdef __getitem__(self, index):random.shuffle(self.imgs)name, label = self.imgs[index]img = Image.open(self.root + name).convert('RGB')if self.transform is not None:img = self.transform(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs)

LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小

3.2 网络类

定义一个网络类,只有两个输出

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.pool = nn.MaxPool2d((2, 2))self.pool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(16, 32, 3)self.fc1 = nn.Linear(36*36*32, 120)self.fc2 = nn.Linear(120, 60)self.fc3 = nn.Linear(60, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool1(F.relu(self.conv2(x)))x = x.view(-1, 36*36*32)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

3.3 主要流程

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Netdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)classes = ['工程车','卡宴']
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',datatxt='./data/'+'train.txt',transform=transform)
test_data=LoadData(root ='./data/val/',datatxt='./data/'+'val.txt',transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 0:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')PATH = './test.pth'
torch.save(net.state_dict(), PATH)net = Net()
net.load_state_dict(torch.load(PATH))correct = 0
total = 0
with torch.no_grad():for data in test_loader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

在这里插入图片描述
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。

3.4 识别代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import NetPATH = './test.pth'
transform = transforms.Compose([transforms.Resize((152, 152)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])net = Net()
net.load_state_dict(torch.load(PATH))img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():outputs = net(img)_, predicted = torch.max(outputs.data, 1)print("the 102 img lable is ",predicted)

如下图所示,102 为卡宴识别为1 正确
在这里插入图片描述

后记

后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/4675.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis之数据类型String、List、Hash、Set、Sorted Set(详细)

一、String数据类型 1、SET/GET/APPEND/STRLEN (1) APPEND (2) SET/STRLEN 2、 INCR/ DECR/INCRBY/DECRBY (1)INCR/ DECR (2) INCRBY/DECRBY INCRBY key increment&#xff1…

selenium通过xpath定位text换行的元素

DOM元素(该元素是换行的,不能通过普通xpath定位): 可使用下面xpath定位该div //div[./text()/following-sibling::text()"点" and ./text()"5"] 解释一下就是:定位“子节点的text是[5] 且 子节点…

【Java项目】解决请求路径上明文ID传输导致可能被攻击的方法

文章目录 问题思路解决 问题 这个问题是我公司的一个小业务问题,问题来源于我们发送请求的时候,请求路径上携带的是明文,比如http://xxx/xxx/id12345,那么别有用心的人就可能会推测出id的生成策略,导致遍历id&#xf…

Qt编写视频监控系统79-四种界面导航栏的设计

一、前言 最初视频监控系统按照二级菜单的设计思路,顶部标题栏一级菜单,左侧对应二级菜单,最初采用图片在上面,文字在下面的按钮方式展示,随着功能的增加,二级菜单越来越多,如果都是这个图文上…

数据库表的操作

目录 前言 1.创建表 2.查看表 2.1查看表结构 2.2查看表中插入的数据 3.修改表 4.删除表 总结 前言 前面已经介绍了对数据库的操作,今天我们介绍的是数据库表的操作,数据库表简单可以理解为存储数据的介质。有了这个认识之后,下面我们…

21.RocketMQ源码之NameServer的路由管理和架构设计

highlight: arduino-light NameServer 路由管理 Broker消息服务器在启动的时向所有NameServer注册。 消息生产者Producer在发送消息之前先从NameServer获取Broker服务器地址列表然后根据负载均衡算法从列表中选择一台服务器进行发送。 NameServer与每台Broker保持长连接&#x…

如何在Microsoft Excel中快速筛选数据

你通常如何在 Excel 中进行筛选?在大多数情况下,通过使用自动筛选,以及在更复杂的场景中使用高级过滤器。 使用自动筛选或 Excel 中的内置比较运算符(如“大于”和“前10项”)来显示所需数据并隐藏其余数据。筛选单元格或表范围中的数据后,可以重新应用筛选器以获取最新…

移动端微信小程序学习

目录 小程序和web端的不同 小程序的宿主环境 通信 组件 视图容器​编辑 text组件 button image ​编辑 API api三大分类 模板语法 事件绑定 ​编辑 事件传参​编辑 bindinput 条件渲染 列表渲染 ​编辑 全局配置 window 页面配置 网络数据请求 ​编辑 GET请求 POST…

从电源 LED 读取智能手机的秘密?

研究人员设计了一种新的攻击方法,通过记录读卡器或智能手机打开时的电源 LED,使用 iPhone 摄像头或商业监控系统恢复存储在智能卡和智能手机中的加密密钥。 众所周知,这是一种侧信道攻击。 通过密切监视功耗、声音、电磁辐射或执行操作所需…

轻松生成高质量用例的API接口工具

1、前言 随着自动化测试技术的普及,已经有很多公司或项目,多多少少都会进行自动化测试。 目前本部门的自动化测试以接口自动化为主,接口用例采用 Excel 进行维护,按照既定的接口用例编写规则,对于功能测试人员来说只…

置换检验临界值

置换检验和t检验一样,会有统计值和P值。 置换检验的统计值记为Z值 其中这个Z和t检验的t一样,是有大小分别的。 例如b为1和2的分类变量,那么Z正值代表1大于2。 我们知道t检验的t值换算成P值,是需要自由度的。 例如在这个数据中&a…

【C#】文件拖拽,获取文件路径

系列文章 【C#】编号生成器(定义单号规则、固定字符、流水号、业务单号) 本文链接:https://blog.csdn.net/youcheng_ge/article/details/129129787 【C#】日期范围生成器(开始日期、结束日期) 本文链接:h…