PyTorch + CRNN Implementation

I recently picked up an instrument-dial recognition project, and after a quick survey it looked like CRNN would do the job. Since I had no dial dataset at hand, I first tried the approach on ICDAR2013.
Along the way I ran into a whole series of pitfalls. To spare readers (and my future self) the same suffering, I have posted the working code below.
1) Please do not change the hyperparameters!!! CRNN training is extremely erratic in the early stages, and only with well-tuned hyperparameters will the loss gradually come down (the exact values used are recapped after the curve below).
[Figure: training loss curve]
The training curve above shows what I mean: it really does look strange, zigzagging all over the place.
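For reference, the settings hard-coded in train.py further down are recapped here for convenience (the names image_size and hidden_units are mine, just for this summary; the script sets those values inline):

# Settings recapped from train.py below -- see point 1 above about not changing them.
n_class = 53              # blank/space + a-z + A-Z
learning_rate = 0.0001    # Adam with betas=(0.9, 0.999)
max_epoch = 100
batch_size = 20           # single-sample DataLoader, batches are accumulated manually
image_size = (32, 100)    # every word crop is resized to height 32, width 100
hidden_units = 256        # LSTM hidden size, i.e. crnn.CRNN(32, 1, n_class, 256)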

2) Do not use Baidu's open-source CTC implementation!!! The training script below relies on PyTorch's built-in torch.nn.CTCLoss instead; a minimal sketch of how it is called follows.
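Here is a minimal, self-contained sketch of how torch.nn.CTCLoss expects its arguments. The tensor names and dummy shapes are illustrative only; blank=0 matches the charset in train.py, where index 0 is the blank/space label.

import torch
import torch.nn as nn

ctc_loss = nn.CTCLoss(blank=0, reduction='mean')

T, N, C = 26, 4, 53                                        # sequence length, batch size, number of classes
log_probs = torch.randn(T, N, C).log_softmax(2)            # (T, N, C) log-probabilities
targets = torch.randint(1, C, (N, 7), dtype=torch.int32)   # label indices, no blanks
input_lengths = torch.full((N,), T, dtype=torch.int32)     # length of each output sequence
target_lengths = torch.full((N,), 7, dtype=torch.int32)    # length of each label

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())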

Network code:

#crnn.py
import torch.nn as nn
import torch.nn.functional as F


class BidirectionalLSTM(nn.Module):
    # Inputs / hidden units / Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    #                  32    1   37      256
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        # print('---forward propagation---')
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)  # b * 512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output
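A quick sanity check that I find helpful (this snippet is not part of the original files): push a dummy batch through the network and confirm the output comes back as (sequence length, batch, n_class), which is the layout CTCLoss wants. With 32x100 inputs the CNN collapses the height to 1, and the sequence length should work out to 26.

# Shape sanity check for the network above (a sketch, not from the original repo).
import torch
from crnn import CRNN

model = CRNN(imgH=32, nc=1, nclass=53, nh=256)   # same arguments as train.py uses
dummy = torch.randn(4, 1, 32, 100)               # batch of 4 grayscale 32x100 crops
out = model(dummy)
print(out.shape)   # expected: torch.Size([26, 4, 53]), i.e. (seq_len, batch, n_class)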

Training:

#train.py
import os
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import crnn
import time
import re
import matplotlib.pyplot as plt
dic={" ":0,"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26,"A":27,"B":28,"C":29,"D":30,"E":31,"F":32,"G":33,"H":34,"I":35,"J":36,"K":37,"L":38,"M":39,"N":40,"O":41,"P":42,"Q":43,"R":44,"S":45,"T":46,"U":47,"V":48,"W":49,"X":50,"Y":51,"Z":52}STR=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
n_class=53
label_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task1_GT"
image_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task12_Images"
use_gpu = True
learning_rate = 0.0001
max_epoch = 100
batch_size = 20
# Image resizing and normalization
class resizeAndNormalize():
    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        # Note: for OpenCV the size is given as (w, h)
        self.size = size
        self.interpolation = interpolation
        # transforms.ToTensor converts a PIL Image or numpy.ndarray to a tensor
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        # For OpenCV the image width maps to the x axis and the height to the y axis
        image = cv2.resize(image, self.size, interpolation=self.interpolation)
        # Convert to a tensor
        image = self.toTensor(image)
        # Optional normalization
        # image = image.sub_(0.5).div_(0.5)
        return image


def load_data(label_folder, image_folder, label_suffix_name=".txt", image_suffix_name=".jpg"):
    image_file, label_file, num_file = [], [], []
    for parent_folder, _, file_names in os.walk(label_folder):
        # Walk every file in the current subfolder
        for file_name in file_names:
            # Only process label files
            # if file_name.endswith(('jpg', 'jpeg', 'png', 'gif')):  # to pick up image files instead
            if file_name.endswith((label_suffix_name)):
                # Build the matching image path from the label file name (gt_xxx.txt -> xxx.jpg)
                a, b = file_name.split("gt_")
                c, d = b.split(label_suffix_name)
                image_name = image_folder + "\\" + c + image_suffix_name
                if os.path.exists(image_name):
                    label_name = label_folder + "\\" + file_name
                    txt = open(label_name, 'rb')
                    txtl = txt.readlines()
                    # One sample per annotated line in the ground-truth file
                    for line in range(len(txtl)):
                        image_file.append(image_name)
                        label_file.append(label_name)
                        num_file.append(line)
    return image_file, label_file, num_file


def zl2lable(zl):
    label_list = []
    for str in zl:
        label_list.append(dic[str])
    return label_list


class NewDataSet(Dataset):
    def __init__(self, label_source, image_source, train=True):
        super(NewDataSet, self).__init__()
        self.image_file, self.label_file, self.num_file = load_data(label_source, image_source)

    def __len__(self):
        return len(self.image_file)

    def __getitem__(self, index):
        txt = open(self.label_file[index], 'rb')
        img = cv2.imread(self.image_file[index], cv2.IMREAD_GRAYSCALE)
        wordL = txt.readlines()
        word = str(wordL[self.num_file[index]])
        pl = re.findall(r'\d+', word)
        zl = re.findall(r"[a-zA-Z]+", word)[1]
        # img tensor: crop the word region out of the page image
        x1, y1, x2, y2 = pl[:4]
        img = img[int(y1):int(y2), int(x1):int(x2), ]
        (height, width) = img.shape
        # The CRNN expects an input height of 32, so the crop is resized
        size_height = 32
        # ratio = 32 / float(height)
        size_width = 100
        transform = resizeAndNormalize((size_width, size_height))
        # Image preprocessing
        imageTensor = transform(img)
        # label tensor
        l = zl2lable(zl)
        labelTensor = torch.IntTensor(l)
        return imageTensor, labelTensor


class CRNNDataSet(Dataset):
    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        self.image_dict = self.readfile(labelRoot)
        self.image_name = [fileName for fileName, _ in self.image_dict.items()]

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.image_name[index])
        keys = self.image_dict.get(self.image_name[index])
        label = [int(x) for x in keys]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # if image is None:
        #     return None, None
        (height, width) = image.shape
        # The CRNN expects an input height of 32, so the image is resized accordingly
        size_height = 32
        ratio = 32 / float(height)
        size_width = int(ratio * width)
        transform = resizeAndNormalize((size_width, size_height))
        # Image preprocessing
        image = transform(image)
        # Convert the label to an IntTensor
        label = torch.IntTensor(label)
        return image, label

    def __len__(self):
        return len(self.image_name)

    def readfile(self, fileName):
        res = []
        with open(fileName, 'r') as f:
            lines = f.readlines()
            for line in lines:
                res.append(line.strip())
        dic = {}
        total = 0
        for line in res:
            part = line.split(' ')
            # Images can turn out to be missing at training time, so check existence up front
            if not os.path.exists(os.path.join(self.image_root, part[0])):
                print(os.path.join(self.image_root, part[0]))
                total += 1
            else:
                dic[part[0]] = part[1:]
        print(total)
        return dic


trainData = NewDataSet(label_sources, image_sources)
trainLoader = DataLoader(dataset=trainData, batch_size=1, shuffle=True, num_workers=0)

# valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
#                       labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")
#
# valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)
#
# def decode(preds):
#     pred = []
#     for i in range(len(preds)):
#         if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i - 1])):
#             pred.append(int(preds[i]))
#     return pred
#
#
def toSTR(l):
    str_l = []
    if isinstance(l, int):
        l = [l]
    for i in range(len(l)):
        str_l.append(STR[l[i]])
    return str_l

def toRES(l):
    new_l = []
    new_str = ' '
    for i in range(len(l)):
        if (l[i] == ' '):
            new_str = ' '
            continue
        elif new_str != l[i]:
            new_l.append(l[i])
            new_str = l[i]
    return new_l


def val(model=torch.load("pytorch-crnn.pth")):
    # Switch the model to evaluation mode
    loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
    model.eval()
    test_n = 10
    for i, (data, label) in enumerate(trainLoader):
        if (i > test_n):
            break
        output = model(data.cuda())
        pred_label = output.max(2)[1]
        input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
        target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
        # forward(self, log_probs, targets, input_lengths, target_lengths)
        # log_probs = output.log_softmax(2).requires_grad_()
        targets = label.cuda()
        loss = loss_func(output.cpu(), targets.cpu(), input_lengths, target_lengths)
        pred_l = np.array(pred_label.cpu().squeeze()).tolist()
        label_l = np.array(targets.cpu().squeeze()).tolist()
        print(i, ":", loss, "pred:", toRES(toSTR(pred_l)), "label_l", toSTR(label_l))


def train():
    model = crnn.CRNN(32, 1, n_class, 256)
    if torch.cuda.is_available() and use_gpu:
        model.cuda()
    loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    lossTotal = 0.0
    k = 0
    printInterval = 100
    start_time = time.time()
    loss_list = []
    total_list = []
    for epoch in range(max_epoch):
        n = 0
        data_list = []
        label_list = []
        label_len = []
        for i, (data, label) in enumerate(trainLoader):
            # Accumulate single samples until a full batch has been collected
            data_list.append(data)
            label_list.append(label)
            label_len.append(label.size(1))
            n = n + 1
            if n % batch_size != 0:
                continue
            k = k + 1
            data = torch.cat(data_list, dim=0)
            data_list.clear()
            label = torch.cat(label_list, dim=1).squeeze(0)
            label_list.clear()
            target_lengths = torch.tensor(np.array(label_len))
            label_len.clear()
            # Switch to training mode
            model.train()
            if torch.cuda.is_available and use_gpu:
                data = data.cuda()
                loss_func = loss_func.cuda()
                label = label.cuda()
            output = model(data)
            log_probs = output
            # The official example applies log_softmax here; passing the network output
            # directly also seems to work, since CRNN.forward already returns log_softmax
            # log_probs = output.log_softmax(2).requires_grad_()
            targets = label.cuda()
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            # forward(self, log_probs, targets, input_lengths, target_lengths)
            # targets = torch.zeros(targets.shape)
            loss = loss_func(log_probs.cpu(), targets, input_lengths, target_lengths) / batch_size
            lossTotal += float(loss)
            print("epoch:", epoch, "num:", i, "loss:", float(loss))
            loss_list.append(float(loss))
            if k % printInterval == 0:
                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                total_list.append(lossTotal / printInterval)
                lossTotal = 0.0
                torch.save(model, 'pytorch-crnn.pth')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    plt.figure()
    plt.plot(loss_list)
    plt.savefig("loss.jpg")
    plt.clf()
    plt.figure()
    plt.plot(total_list)
    plt.savefig("total.jpg")
    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))
    return model


if __name__ == '__main__':
    train()
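Since train() saves the entire model object with torch.save(model, 'pytorch-crnn.pth'), running the trained network on a single crop is straightforward. The sketch below is mine, not part of the original script: the image path is a placeholder, crnn.py must be importable (the checkpoint pickles the full CRNN object), and the greedy decode mirrors what toSTR/toRES do in val().

# inference sketch (assumptions: checkpoint and crnn.py in the current folder,
# "some_word_crop.jpg" is a placeholder path to any cropped word image)
import cv2
import torch
from torchvision import transforms

STR = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"   # same charset as train.py

def greedy_decode(output):
    # output: (T, 1, n_class) log-probabilities from the CRNN
    pred = output.max(2)[1].squeeze(1).cpu().tolist()   # best label per time step
    chars, prev = [], 0
    for p in pred:
        if p != 0 and p != prev:      # drop blanks (index 0) and repeated labels
            chars.append(STR[p])
        prev = p
    return "".join(chars)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load("pytorch-crnn.pth", map_location=device)   # whole-model checkpoint from train()
model.eval()

img = cv2.imread("some_word_crop.jpg", cv2.IMREAD_GRAYSCALE)  # placeholder path
img = cv2.resize(img, (100, 32))                              # same size as used in training
tensor = transforms.ToTensor()(img).unsqueeze(0)              # -> (1, 1, 32, 100)

with torch.no_grad():
    output = model(tensor.to(device))
print(greedy_decode(output))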

The test results are as follows:
[Figure: test results]
Finally, some references:
https://www.cnblogs.com/azheng333/p/7449515.html
https://blog.csdn.net/wzw12315/article/details/106643182

The dataset and my trained model are also available here:
Link: https://pan.baidu.com/s/1-jTA22bLKv2ut_1EJ1WMKA?pwd=jvk8
Extraction code: jvk8
