OCR识别网络CRNN理解与Pytorch实现

CRNN是2015年的论文“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”提出的图像字符识别网络,也是目前工业界使用较为广泛的一个OCR网络。论文地址:https://arxiv.org/abs/1507.05717

1. 网络结构

CRNN是一个端到端可训练的网络,并且可处理任意长度的字符序列。CRNN得名于Convolutional Recurrent Neural Network,从名称即可看出,该网络包含了卷积网络和递归网络。实际上,CRNN由三部分组成,分别是卷积层部分(Convolutional layers)、递归层部分(Recurrent Layers)和转录层部分(Transcription Layers),如下图所示:

其中, 卷积层的作用是从输入图像中提取特征,递归层则对卷积层输出的feature maps进行预测,最后,转录层将递归层的预测结果翻译成文字标签序列。CNN和RNN可由同一个损失函数进行联合训练。

在图像输入CRNN之前,需要缩放到指定高度height,宽度无限制。卷积层输出的feature maps在送入RNN之前,从左到右生成一个feature vector序列,第i个feature vector为feature maps第i列的元素的级联。这样做的好处是,每个feature vector代表了原图像上一个矩形区域的特征(感受野),使得网络能够预测不同长度的字符序列。

RNN网络的优势在于,它能够有效利用序列的上下文信息进行预测,比分别预测单个字符有更好的精确度和稳定性,同时,它对输入序列的长度无限制,比单纯使用CNN网络更加灵活。

由于传统RNN存在梯度爆炸和梯度消失问题,因此,在这篇文章中,作者采用了LTSM(Long-Short Term Memory)来克服该问题。一个LSTM包含一个记忆单元(Memory Cell)和三个乘法门(Multiplicative gates),分别为输入门(input gate)、输出门(output gate)和遗忘门(forget gate),如下图所示:

由于基于图像的文字识别具有较强的前向和后向上下文信息,因此,使用双向LSTM(bidirectional LSTM)是一个合适的选择。 

转录层将RNN层的预测结果(用概率表示)映射到字符序列。 在实践中,存在两种转录模式,分别是基于词典的转录,和无词典转录。在基于词典的模式中,会选择词典中最高概率的标签进行预测;而无词典模式,预测则是在无任何词典的情况下进行的。

CRNN的具体网络结构及配置如下:

2. 代码实现 

网上找到一个CRNN的Pytorch实现,亲测好用,代码链接:CRNN Pytorch

网络定义:

import torch.nn as nnclass CRNN(nn.Module):def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):super(CRNN, self).__init__()assert imgH % 16 == 0, 'imgH has to be a multiple of 16'ks = [3, 3, 3, 3, 3, 3, 2]ps = [1, 1, 1, 1, 1, 1, 0]ss = [1, 1, 1, 1, 1, 1, 1]nm = [64, 128, 256, 256, 512, 512, 512]cnn = nn.Sequential()def convRelu(i, batchNormalization=False):nIn = nc if i == 0 else nm[i - 1]nOut = nm[i]cnn.add_module('conv{0}'.format(i),nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))if batchNormalization:cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))if leakyRelu:cnn.add_module('relu{0}'.format(i),nn.LeakyReLU(0.2, inplace=True))else:cnn.add_module('relu{0}'.format(i), nn.ReLU(True))convRelu(0)cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64convRelu(1)cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32convRelu(2, True)convRelu(3)cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16convRelu(4, True)convRelu(5)cnn.add_module('pooling{0}'.format(3),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16convRelu(6, True)  # 512x1x16self.cnn = cnnself.rnn = nn.Sequential(BidirectionalLSTM(512, nh, nh),BidirectionalLSTM(nh, nh, nclass))def forward(self, input):# conv featuresconv = self.cnn(input)b, c, h, w = conv.size()assert h == 1, "the height of conv must be 1"conv = conv.squeeze(2)conv = conv.permute(2, 0, 1)  # [w, b, c]# rnn featuresoutput = self.rnn(conv)return output

 其中,Bidirectional LSTM的定义如下:

class BidirectionalLSTM(nn.Module):def __init__(self, nIn, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)self.embedding = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, _ = self.rnn(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.embedding(t_rec)  # [T * b, nOut]output = output.view(T, b, -1)return output

Demo是基于字典的转录方式,可以识别0~9的1-0个数字,以及a~z的26个字母。

import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Imageimport models.crnn as crnnmodel_path = './data/crnn.pth'  # 与训练模型路径
img_path = './data/demo.png'    # 测试图片路径
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'    # 字典model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))converter = utils.strLabelConverter(alphabet)    # 定义字典转录函数transformer = dataset.resizeNormalize((100, 32))   # 图像预处理函数
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():image = image.cuda()
print('image size: ', image.shape)
image = image.view(1, *image.size())
image = Variable(image)model.eval()
preds = model(image)    # CRNN预测_, preds = preds.max(2)    # 找到最大概率所对应的index
preds = preds.transpose(1, 0).contiguous().view(-1)preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)   # 逐一输出预测字符,如a-----v--a-i-l-a-bb-l-e---
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)  # 输出最终预测结果,如available
print('%-20s => %-20s' % (raw_pred, sim_pred))

 demo执行结果:a-----v--a-i-l-a-bb-l-e--- => available  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/418485.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony AI框架开发指导

一、概述 1、 功能简介 AI业务子系统是OpenHarmony提供原生的分布式AI能力的子系统。AI业务子系统提供了统一的AI引擎框架,实现算法能力快速插件化集成。 AI引擎框架主要包含插件管理、模块管理和通信管理模块,完成对AI算法能力的生命周期管理和按需部…

检索增强(RAG)的方式---重排序re-ranking

提升RAG:选择最佳嵌入Embedding&重排序Reranker模型 检索增强生成(RAG)技术创新进展:自我检索、重排序、前瞻检索、系统2注意力、多模态RAG RAG的re-ranking指的是对初步检索出来的候选段落或者文章,通过重新排序的方式来提升检索质量。…

红包封面免费送1000个,你设计,我出额度

相信最近大家或多或少都知道了吧,腾讯又又又给大家,准确的说是给一年勤奋的公众号/视频号博主一个福利 根据不同博主的粉丝、更新频度以及作品质量,给力博主们免费制作红包封面的福利 比如我这个号,有6000额度 那这6000个&#…

从规则到神经网络:机器翻译技术的演化之路

文章目录 从规则到神经网络:机器翻译技术的演化之路一、概述1. 机器翻译的历史与发展2. 神经机器翻译的兴起3. 技术对现代社会的影响 二、机器翻译的核心技术1. 规则基础的机器翻译(Rule-Based Machine Translation, RBMT)2. 统计机器翻译&am…

【STM32】STM32学习笔记-I2C通信协议(31)

00. 目录 文章目录 00. 目录01. I2C简介02. I2C主要特点03. I2C硬件电路04. I2C时序基本单元05. I2C时序波形图06. 附录 01. I2C简介 I2C(Inter-Integrated Circuit)总线是一种由NXP(原PHILIPS)公司开发的两线式串行总线,用于连接…

【高等数学之定积分】

一、什么是定积分? 我们第一次新手司机开车从某一地方到(一数)家,自始至终保持着匀速行驶,那么这个过程所经历的路程是什么呢?用速度-时间函数图像来表示一下,我们发现其实就是其曲线下的面积。 第二次开车我们已经有了一定的经验&#xff…

vivado 定义板文件板

定义板文件板 &#xff1c;board&#xff1e;标记是板文件的根。它包括识别基本信息的属性关于董事会。 <board schema_version"2.1" vendor"xilinx.com" name"kc705" display_name"Kintex-7 KC705 Evaluation Platform" url&qu…

数据结构小项目----通讯录的实现(这里用链表实现) 超详细~~~~૮(˶ᵔ ᵕ ᵔ˶)ა

目录 Contact.h说明&#xff1a; 结构体与头文件的包含&#xff1a; ​编辑 函数在头文件的声明与定义&#xff1a; Contact.c中各个函数的实现&#xff1a; 1.检查链表中的数据是否满了&#xff0c;满了就扩容 2.链表的尾插 3.链表的删除 4.查找名字是否匹配 5.初始化通讯…

代码随想录二刷 |二叉树 | 将有序数组转换为二叉搜索树

代码随想录二刷 |二叉树 | 将有序数组转换为二叉搜索树 题目描述解题思路代码实现 题目描述 109.将有序数组转换为二叉搜索树 将一个按照升序排列的有序数组&#xff0c;转换为一棵高度平衡二叉搜索树。 本题中&#xff0c;一个高度平衡二叉树是指一个二叉树每个节点 的左右…

Zabbix 系统监控详解

1 介绍 1.1 摘要 本文深入浅出&#xff0c;切近实际运维应用&#xff0c;由 zabbix 3.4 版本入手&#xff0c;学习 zabbix 监控告警实现方式&#xff0c;由 zabbix 5.0 浅出实现快速部署、快速应用。本人从业多年&#xff0c;关注 zabbix 开源社区&#xff0c;以及 zabbix 官…

企业邮箱:定义、功能与优势一览

本文将为大家讲解&#xff1a;1、企业邮箱的定义&#xff1b;2、企业邮箱的主要功能特点&#xff1b;3、企业邮箱如何选择和部署&#xff1b;4、企业邮箱的运营与维护&#xff1b;5、企业邮箱在实际工作中的应用与挑战&#xff1b;6、2024年最新五大企业邮箱盘点 下面提到的功能…

【kali后续配置】Kali Linux 换源更新及配置SSH服务一文通(一键换源脚本)

在前面&#xff0c;我们已经下载并安装了Kali Linux 2023版本&#xff0c;因为一些事情的耽误&#xff0c;后面的一些操作教程没有发出来&#xff0c;今天给大家补上。 Kali Linux安装 前置准备&#xff1a; VMware安装kali Linux 镜像下载 kali源 官方源 deb http://htt…