pytorch车牌识别

目录

  • 使用pytorch库中CNN模型进行图像识别
    • 收集数据集
    • 定义CNN模型
      • 卷积层
      • 池化层
      • 全连接层
    • CNN模型代码
    • 使用模型

在这里插入图片描述

使用pytorch库中CNN模型进行图像识别

收集数据集

可以去找开源的数据集或者自己手做一个
最终整合成 类别分类的图片文件
在这里插入图片描述

定义CNN模型

卷积层

功能:提取特征

概念

  1. 卷积层输入层通道数

如果输入数据是彩色图像,那么通常情况下,输入数据具有三个通道(红、绿、蓝),因此第一个卷积层的输入通道数应该为3。
如果输入数据是灰度图像,那么输入通道数通常为 1。

  1. 卷积层输出层通道数

卷积层的输出通道数控制着该层提取的特征的数量和复杂度。更多的输出通道意味着网络可以学习更多种类的特征,但过多的输出通道数会导致复杂度和过拟合。

池化层

功能:使卷积层的特征更加明显,对图像进行降维压缩(舍弃无关特征,避免过拟合),提高神经网络的泛华能力。
问题:

  1. 最大池化操作

最大池化操作是一种常用的池化操作,用于减少特征图的空间维度并保留最重要的特征信息
在这里插入图片描述

# 定义最大池化层,池化窗口大小为 2x2,步幅为 2
max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)

全连接层

将特征进行整合,然后归一化,对各种分类情况都输入一个概率,根据概率进行分类

CNN模型代码

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
# 进度条工具
from tqdm import tqdm# 数据集中的类别数
num_classes = len(os.listdir('./数据集'))
# 训练的轮数
num_epochs = 10
# 30次:['陕', '陕', 'U', 'U', '6', '6', '6', '6']
# 10次:['陕', 'A', 'D', '0', '6', '6', '6', '6']# 一、定义数据预处理和数据加载器
transform = transforms.Compose([# 固定图像大小transforms.Resize((64, 64)),# 将图像转换为灰度图像transforms.Grayscale(),# 将图像转换为张量transforms.ToTensor(),
])
# 使用ImageFolder定义数据集,标签为序号
train_dataset = ImageFolder(root='./数据集', transform=transform)
# 数据加载器,每个批次包含32张图像
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 二、定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()# 卷积层1  1代表单通道,黑白;32代表输出通道;3代表3*3的卷积核, 1代表在最外围补一圈0self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)# 池化层1  最大池化操作,2代表尺寸减半self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2 ,32对于卷积层1的输出通道数self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)# 全连接层 64输出通道数,16*16代表压缩后的尺寸,生成长度128向量self.fc1 = nn.Linear(64 * 16 * 16, 128)self.fc2 = nn.Linear(128, num_classes)# 前向传播 返回输出结果def forward(self, x):# 卷积1x = self.conv1(x)# 激活函数/激化函数 引入非线性变化,增强神经网络复杂性x = torch.relu(x)# 池化x = self.pool(x)x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, 64 * 16 * 16)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 三、初始化模型、损失函数和优化器
model = CNNModel()
criterion = nn.CrossEntropyLoss()
# 学习率一般设0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)# 四、只要当主文件运行时候,才训练模型
if __name__ == "__main__":for epoch in range(num_epochs):running_loss = 0.0print(f'Epoch : {epoch + 1}/{num_epochs}')# 显示每轮的进度条for images, labels in tqdm(train_loader):#  将优化器中存储的之前计算的梯度归零optimizer.zero_grad()# 将输入图像数据 images 输入到模型中进行前向传播,得到模型的输出outputs = model(images)# 损失函数 criterion 计算模型 输出 与 真实标签 之间的损失值。loss = criterion(outputs, labels)# 对损失值进行反向传播,计算模型参数的梯度loss.backward()# 据优化算法(梯度下降)更新模型参数,最小化损失函数optimizer.step()running_loss += loss.item()# 输出每个 epoch 的平均损失epoch_loss = running_loss / len(train_loader)print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')# 保存模型torch.save(model.state_dict(), 'cnn_model.pt')

使用模型

import torch
from PIL import Image
from torch.utils.data import dataset
from cnn_model import transform, train_dataset, CNNModel# 加载整个模型
model = CNNModel()
# 将模型设置为评估模式
model.eval()
checkpoint = torch.load('./cnn_model.pt')
model.load_state_dict(checkpoint)# 使用模型进行预测,识别单个文字图片
def predict_image(image_path):image = Image.open(image_path)# 转换图片格式image = transform(image)# 只进行前向传播with torch.no_grad():output = model(image)# ImageFolder输出的标签是文件序号,argmax找到张量output中的最大值predicted_idx = torch.argmax(output).item()print(predicted_idx)# 将输出转换成对应序号的文件名if predicted_idx < len(train_dataset.classes) :predicted_label = train_dataset.classes[predicted_idx]return predicted_labelelse:return "null"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/616557.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Matlab之过球面一点的平面方程

这篇文章描述2件事情&#xff1a; 1、已知球面上任意点&#xff0c;求过该点、地心、与北极点的平面方程&#xff08;即过该点的经线平面方程&#xff09;&#xff1b; 2、绕过球心的任意轴旋转平面得到新平面的方程 一、已知球面上任意点&#xff0c;求过该点、地心、与北极点…

NPM 命令备忘单

NPM 简介 Node Package Manager (NPM) 是 Node.js 环境中不可或缺的命令行工具&#xff0c;充当包管理器来安装、更新和管理 Node.js 应用程序的库、包和模块。对于每个 Node.js 开发人员来说&#xff0c;无论他们的经验水平如何&#xff0c;它都是一个关键工具。 NPM 的主要…

MyBatis批量插入的五种方式

MyBatis批量插入的五种方式: 一、准备工作 1、导入pom.xml依赖 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><!-- MySQL驱动依赖 --…

文件操作(C语言)

目录 1.为什么使用文件&#xff1f; 2.什么是文件&#xff1f; 2.1程序文件 2.2数据文件 2.3文件名 3.二进制文件和文本文件 4.文件的打开和关闭 4.1流和标准流 4.1.1流 4.1.2标准流 4.2文件指针 4.3文件的打开和关闭 5.文件的顺序读写 5.1顺序读写函数介绍 5.2…

12 Php学习:魔术常量

PHP魔术常量 PHP 向它运行的任何脚本提供了大量的预定义常量。 不过很多常量都是由不同的扩展库定义的&#xff0c;只有在加载了这些扩展库时才会出现&#xff0c;或者动态加载后&#xff0c;或者在编译时已经包括进去了。 有八个魔术常量它们的值随着它们在代码中的位置改…

【Unity】常见性能优化

1 前言 本文将介绍下常用的Unity自带的常用优化工具&#xff0c;并介绍部分常用优化方法。都是比较基础的内容。 2 界面 2.1 Statistics窗口 可以简单查看Unity运行时的统计数据&#xff0c;当前一帧的性能数据。 2.1.1 Audio 音频相关内容。 Level&#xff1a;音量大小&a…

Java 原生代码获取服务器的网卡 Mac 地址、CPU序列号、主板序列号

1、概述 Java 可以获取服务器的网卡 Mac 地址、CPU 序列号、主板序列号等信息&#xff0c;用来做一些软件授权验证、设备管理等场景。 2、代码实现 package com.study.util;import java.net.InetAddress; import java.net.NetworkInterface; import java.util.Scanner;/*** …

gma 2.0.8 (2024.04.12) 更新日志

安装 gma 2.0.8 pip install gma2.0.8网盘下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1P0nmZUPMJaPEmYgixoL2QQ?pwd1pc8 提取码&#xff1a;1pc8 注意&#xff1a;此版本没有Linux版&#xff01; 编译gma的Linux虚拟机没有时间修复&#xff0c;本期Linux版继…

理想大模型实习面试题6道|含解析

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学&#xff0c;针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总…

STM32外设配置以及一些小bug总结

USART RX的DMA配置 这里以UART串口1为例&#xff0c;首先点ADD添加RX和TX配置DMA&#xff0c;然后模式一般会选择是normal&#xff0c;这个模式是当DMA的计数器减到0的时候就不做任何动作了&#xff0c;还有一种循环模式&#xff0c;是计数器减到0之后&#xff0c;计数器自动重…

计算机毕业设计Python+Flask电商商品推荐系统 商品评论情感分析 商品可视化 商品爬虫 京东爬虫 淘宝爬虫 机器学习 深度学习 人工智能 知识图谱

一、选题背景与意义 1.国内外研究现状 国外研究现状&#xff1a; 亚马逊&#xff08;Amazon&#xff09;&#xff1a;作为全球最大的电商平台之一&#xff0c;亚马逊在数据挖掘和大数据方面具有丰富的经验。他们利用Spark等大数据技术&#xff0c;构建了一套完善的电商数据挖…

算法100例(持续更新)

算法100道经典例子&#xff0c;按算法与数据结构分类 1、祖玛游戏2、找下一个更大的值3、换根树状dp4、一笔画完所有边5、树状数组&#xff0c;数字1e9映射到下标1e56、最长回文子序列7、超级洗衣机&#xff0c;正负值传递次数8、Dijkstra9、背包问题&#xff0c;01背包和完全背…