【模型训练】-图形验证码识别

针对网站中的图形验证码图片,进行反向的内容识别,支持数字和字母,不区分大小写。

​​​​​​​​​​​​​​数据集地址

数据格式如下:

1、依赖导入

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Imageimport numpy as np
import pickle as pkl
import matplotlib.pyplot as plt

2、数据集创建

class Dataset(Dataset):def __init__(self, img_dir):path_list = os.listdir(img_dir)# 获取文件夹绝对路径abspath = os.path.abspath(img_dir)self.img_list = [os.path.join(abspath, path) for path in path_list]self.transform = transforms.Compose([# 灰度化,配合 卷积网络初始通过 1# transforms.Grayscale(), transforms.ToTensor(),])def __len__(self):return len(self.img_list)def __getitem__(self, idx):path = self.img_list[idx]label = os.path.basename(path).split('.')[0].lower().strip()img = Image.open(path).convert('RGB')img_tensor = self.transform(img)return img_tensor, label

3、创建crnn卷积循环神经网络

stride 步长 

padding 完成卷积后是否填充空白

MaxPool2d :减少数据空间大小,池化窗口的大小,通常设置为2×2。减少参数数量和计算量,同时也能提高模型的鲁棒性。

BatchNorm(512):对输入数据进行归一化处理,使得每个通道的数据均值为0,方差为1,提高模型的泛化能力

dropout:随机丢弃神经元的输出来减少模型的复杂度和过拟合的风险

nn.GRU:PyTorch中的一个函数,用于创建一个双向的GRU(门控循环单元)层。

参数解释如下:

  • 255:输入的特征维度。输入数据的特征维度为255
  • 255:隐藏状态的维度。隐藏状态的维度为255
  • bidirectional=True:表示是否使用双向GRU。如果设置为True,则使用双向GRU;如果设置为False,则使用单向GRU。
  • batch_first=True:表示输入数据的维度顺序。如果设置为True,则输入数据的维度顺序为(batch_size, sequence_length, feature_dim);如果设置为False,则输入数据的维度顺序为(sequence_length, batch_size, feature_dim)。
class CRNN(nn.Module):def __init__(self, vocab_size, dropout=0.5):super(CRNN, self).__init__()self.dropout = nn.Dropout(dropout)self.convlayer = nn.Sequential(# 如果预处理采用Grayscale 则 channel=1nn.Conv2d(3, 32, (3,3), stride=1, padding=1),# 激活函数,x小于0,y=0nn.ReLU(),nn.MaxPool2d((2,2), 2),nn.Conv2d(32, 64, (3,3), stride=1, padding=1),nn.ReLU(),nn.MaxPool2d((2,2), 2),nn.Conv2d(64, 128, (3,3), stride=1, padding=1),nn.ReLU(),nn.Conv2d(128, 256, (3,3), stride=1, padding=1),nn.ReLU(),nn.MaxPool2d((1,2), 2),nn.Conv2d(256, 512, (3,3), stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, (3,3), stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d((1,2), 2),nn.Conv2d(512, 512, (2,2), stride=1, padding=0),self.dropout)self.mapSeq = nn.Sequential(nn.Linear(1024, 256),self.dropout)self.lstm_0 = nn.GRU(256, 256, bidirectional=True)self.lstm_1 = nn.GRU(512, 256, bidirectional=True)self.out = nn.Sequential(nn.Linear(512, vocab_size),)def forward(self, x):x = self.convlayer(x)x = x.permute(0, 3, 1, 2)x = x.view(x.size(0), x.size(1), -1)x = self.mapSeq(x)x, _ = self.lstm_0(x)x, _ = self.lstm_1(x)x = self.out(x)return x.permute(1, 0, 2)

4、创建模型


class OCR:def __init__(self):self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')self.crnn = CRNN(VOCAB_SIZE).to(self.device)print('Model loaded to ', self.device)self.critertion = nn.CTCLoss(blank=0)self.char2idx, self.idx2char = self.char_idx()def char_idx(self):char2idx = {}idx2char = {}characters = CHARS.lower() + '-'for i, char in enumerate(characters):char2idx[char] = i + 1idx2char[i+1] = charreturn char2idx, idx2chardef encode(self, labels):length_per_label = [len(label) for label in labels] joined_label = ''.join(labels)joined_encoding = []for char in joined_label:joined_encoding.append(self.char2idx[char])return (torch.IntTensor(joined_encoding), torch.IntTensor(length_per_label)) def decode(self, logits):tokens = logits.softmax(2).argmax(2).squeeze(1)tokens = ''.join([self.idx2char[token]if token !=0 else '-'for token in tokens.numpy()])tokens = tokens.split('-')text = [char for batch_token in tokensfor idx, char in enumerate(batch_token)if char != batch_token[idx-1] or len(batch_token) == 1]    text = ''.join(text)  return textdef calculate_loss(self, logits, labels):encoded_labels, labels_len = self.encode(labels)logits_lens = torch.full(size=(logits.size(1),),fill_value = logits.size(0),dtype = torch.int32).to(self.device)return self.critertion(logits.log_softmax(2), encoded_labels,logits_lens, labels_len)def train_step(self, optimizer, images, labels):logits = self.predict(images)optimizer.zero_grad()loss = self.calculate_loss(logits, labels)loss.backward()optimizer.step()return logits, lossdef val_step(self, images, labels):logits = self.predict(images)loss = self.calculate_loss(logits, labels)return logits, lossdef predict(self, img):return self.crnn(img.to(self.device))def train(self, num_epochs, optimizer, train_loader, val_loader, print_every = 2):train_losses, valid_losses = [],[]for epoch in range(num_epochs):tot_train_loss = 0self.crnn.train()for i, (images, labels) in enumerate(train_loader):logits, train_loss = self.train_step(optimizer, images, labels)tot_train_loss += train_loss.item()with torch.no_grad():tot_val_loss = 0self.crnn.eval()for i, (images, labels) in enumerate(val_loader):logits, val_loss = self.val_step(images, labels)tot_val_loss += val_loss.item()train_loss = tot_train_loss / len(train_loader.dataset)valid_loss = tot_val_loss / len(val_loader.dataset)train_losses.append(train_loss)valid_losses.append(valid_loss)if epoch % print_every == 0:print('Epoch [{:5d}/{:5d}] | train loss {:6.4f} | val loss {:6.4f}'.format(epoch + 1, num_epochs, train_loss, val_loss))                return train_losses, valid_losses

5、开启训练


TRAIN_DIR = '../data/train'
VAL_DIR = '../data/val'# batch_size lr 参数值训练,得到的结果较合适
BATCH_SIZE = 8
N_WORKERS = 0
EPOCHS = 20CHARS ='abcdefghijklmnopqrstuvwxyz0123456789'
VOCAB_SIZE = len(CHARS) + 1lr = 0.02
# 权重衰减
weight_decay = 1e-5
# 下降幅度
momentum = 0.7train_dataset = Dataset(TRAIN_DIR)
val_dataset = Dataset(VAL_DIR)train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE,num_workers = N_WORKERS, shuffle=True
)val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,num_workers=N_WORKERS, shuffle=False
)ocr = OCR()optimizer = optim.SGD(ocr.crnn.parameters(), lr =lr, nesterov=True,weight_decay=weight_decay, momentum=momentum
) train_losses, val_losses = ocr.train(EPOCHS, optimizer, train_loader, val_loader, print_every=1)

6、随机采样,验证模型


sample_result = []for i in range(10):idx = np.random.randint(len(val_dataset))img, label = val_dataset.__getitem__(idx)logits = ocr.predict(img.unsqueeze(0))pred_text = ocr.decode(logits.cpu())sample_result.append((img, label, pred_text))fig = plt.figure(figsize=(17,5))    
for i in range(10):ax = fig.add_subplot(2, 5, i+1, xticks=[], yticks=[])img, label, pred_text = sample_result[i]title = f'Truth: {label} | Pred: {pred_text}'ax.imshow(img.permute(1,2, 0))ax.set_title(title)plt.show()

7、输出统计图

plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Valid Loss')
plt.title('Loss stats')
plt.legend()
plt.show()

8、外部数据验证

trans = transforms.Compose([# 取决于与处理中是否也做相同处理transforms.Grayscale(),# 原始数据集图片尺寸transforms.Resize([50, 200]),transforms.ToTensor(),
])image = Image.open('../data/123.png').convert('RGB')
tensor_img = trans(image)
result = ocr.predict(tensor_img.unsqueeze(0))
text = ocr.decode(result.cpu())
print(text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/512504.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot的新闻推荐系统论文

基于springbootvue的新闻推荐系统 摘要 随着信息互联网购物的飞速发展,国内放开了自媒体的政策,一般企业都开始开发属于自己内容分发平台的网站。本文介绍了新闻推荐系统的开发全过程。通过分析企业对于新闻推荐系统的需求,创建了一个计算机…

数据删除

目录 数据删除 删除员工编号为 7369 的员工信息 删除若干个数据 删除公司中工资最高的员工 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 数据删除 删除数据就是指删除不再需要的数据 delete from 表名称 [where 删…

#QT(网络编程-UDP)

1.IDE:QTCreator 2.实验:UDP 不分客户端和服务端 3.记录 (1)做一个UI界面 (2)编写open按钮代码进行测试(用网络调试助手测试) (3)完善其他功能测试 4.代码 …

gazebo平衡车模拟

gazebo和Ros中的平衡车模拟(Noetic) 控制原理 使用说明 在URDF模型中使用gazebo的 imu 插件获取平衡车姿态从 /joint_state 话题消息获取两轮的速度,相当于电机编码器速度环和直立环使用 串级PID 控制,框图如下:转向环…

PHP+MySQL实现后台管理系统增删改查之够用就好

说明 最近要给博客弄个后台,不想搞得很复杂,有基本的增删改查就够了,到网上找了一圈发现这个不错,很实用,希望可以帮到大家,需要的朋友评论区留下邮箱,我安排发送。 演示效果 项目介绍 本项目…

吴恩达机器学习笔记:第5周-9 神经网络的学习1(Neural Networks: Learning)

目录 9.1 代价函数9.2 反向传播算法9.3 反向传播算法的直观理解 9.1 代价函数 首先引入一些便于稍后讨论的新标记方法: 假设神经网络的训练样本有𝑚个,每个包含一组输入𝑥和一组输出信号𝑦,𝐿…

2024大广赛参赛流程分享

自2005年第一届以来,全国大学生广告艺术大赛(以下简称大广赛)遵循“促进教育改革、启迪智慧、增强能力、提高素质、培养人才”的竞赛宗旨,成功举办了14届15届大赛,共有1857所高校参加,100多万学生提交作品。…

【详识JAVA语言】String类oj练习

1. 第一个只出现一次的字符 class Solution { public int firstUniqChar(String s) {int[] count new int[256];// 统计每个字符出现的次数for(int i 0; i < s.length(); i){count[s.charAt(i)];}// 找第一个只出现一次的字符for(int i 0; i < s.length(); i){if(1 …

生产工厂数据中台解决方案:打造可视化平台,为工业搭建智慧大脑-亿发

制造数据中台是将企业现有的业务软件系统进行整合并打通&#xff0c;形成一套标准模块化框架&#xff0c;然后在此基础上构建一个统一的信息服务和应用平台。数据中台的建设涵盖了诸如 ERP&#xff08;供应链管理&#xff09;、MES&#xff08;制造执行管理&#xff09;、SRM&a…

激光雷达点云数据邻域特征计算理论知识学习

一、数学理论 &#xff08;一&#xff09;SVD奇异值分解&#xff08;Singular value decomposition&#xff09; 奇异值分解是线性代数中一种重要的矩阵分解&#xff0c;在信号处理、统计学等领域有重要应用。奇异值分解在某些方面与对称矩阵或埃尔米特矩阵基于特征向量的对角…

哪些型号的高速主轴适合PCB分板机

在选择适合PCB分板机的高速主轴时&#xff0c;SycoTec品牌提供了丰富的型号选择&#xff0c;主要型号包括4025 HY、4033 AC&#xff08;电动换刀&#xff09;、4033 AC-ESD、4033 DC-T和4041 HY-ESD等。 那么如何选择合适的PCB分板机高速主轴型号呢&#xff1f;在选择适合PCB分…

06. Nginx进阶-Nginx代理服务

proxy代理功能 正向代理 什么是正向代理&#xff1f; 正向代理&#xff08;forward proxy&#xff09;&#xff0c;一个位于客户端和原始服务器之间的服务器。 工作原理 为了从原始服务器获取内容&#xff0c;客户端向代理发送一个请求并指定目标&#xff08;即原始服务器…