文章目录
- 0 前期教程
- 1 什么是模型部署
- 2 怎么部署
0 前期教程
-
【YOLO】朴实无华的yolov5环境配置
-
【YOLO】yolov5训练自己的数据集
1 什么是模型部署
前期教程当中,介绍了yolov5环境的搭建以及如何利用yolov5进行模型训练和测试,虽然能够实现图片或视频的目标识别,但都是基于pytorch这个深度学习框架来实现的。仅仅是为了使用训练好的模型,就需要附加一个巨大的框架,这样程序会显得很臃肿,不够优雅。因此,摆脱对深度学习框架的依赖,是非常有必要的。此即深度学习模型的部署。
2 怎么部署
这里使用的是opencv的dnn模块,可以实现读取并使用深度学习模型。但是,这个模块不支持pytorch模型,即训练好的pt格式的文件,因此,使用该模型时,还需要先将pt文件转换为opencv能够读取的模型格式,即onnx。
模型格式的转换使用的是yolov5自带的export.py文件,它提供了多种常见深度学习框架对应的文件格式。老规矩,使用前先看文件开头的注释:
我们需要的是onnx格式,因此在运行前先安装onnx:
pip install onnx
然后运行export.py文件:
python export.py --weights 'C:\Users\Zeoy\Desktop\Code\Python\yolov5-master\runs\train\exp19\weights\best.pt' --include onnx
生成的onnx文件也在原best.pt所在文件夹下。
转换完毕,接下来就是使用,运行如下所示代码:
import cv2
import numpy as npclass Onnx_clf:def __init__(self, onnx:str='best.onnx', img_size=640, classlist:list=['bottle']) -> None:''' @func: 读取onnx模型,并进行目标识别@para onnx:模型路径img_size:输出图片大小,和模型直接相关classlist:类别列表@return: None'''self.net = cv2.dnn.readNet(onnx) # 读取模型self.img_size = img_size # 输出图片尺寸大小self.classlist = classlist # 读取类别列表def img_identify(self, img, ifshow=True) -> np.ndarray:''' @func: 图片识别@para img: 图片路径或者图片数组ifshow: 是否显示图片@return: 图片数组'''if type(img) == str: src = cv2.imread(img)else: src = imgheight, width, _ = src.shape #注意输出的尺寸是先高后宽_max = max(width, height)resized = np.zeros((_max, _max, 3), np.uint8)resized[0:height, 0:width] = src # 将图片转换成正方形,防止后续图片预处理(缩放)失真# 图像预处理函数,缩放裁剪,交换通道 img scale out_size swapRBblob = cv2.dnn.blobFromImage(resized, 1/255.0, (self.img_size, self.img_size), swapRB=True)prop = _max // self.img_size # 计算缩放比例self.net.setInput(blob) # 将图片输入到模型out = self.net.forward() # 模型输出# print(out.shape)out = np.array(out[0])out = out[out[:, 4] >= 0.5] # 利用numpy的花式索引,速度更快boxes = out[:, :4]confidences = out[:, 4]class_ids = np.argmax(out[:, 5:], axis=1)class_scores = np.max(out[:, 5:], axis=1)# out2 = out[0][out[0][:][4] > 0.5]# for i in out[0]: # 遍历每一个框# class_max_score = max(i[5:])# if i[4] < 0.5 or class_max_score < 0.25: # 过滤置信度低的目标# continue# boxes.append(i[:4]) # 获取目标框: x,y,w,h (x,y为中心点坐标)# confidences.append(i[4]) # 获取置信度# class_ids.append(np.argmax(i[5:])) # 获取类别id# class_scores.append(class_max_score) # 获取类别置信度indexes = cv2.dnn.NMSBoxes(boxes, confidences, 0.25, 0.45) # 非极大值抑制, 获取的是索引for i in indexes: # 遍历每一个目标, 绘制目标框box = boxes[i]class_id = class_ids[i]score = round(class_scores[i], 2)x1 = int((box[0] - 0.5*box[2])*prop)y1 = int((box[1] - 0.5*box[3])*prop)x2 = int((box[0] + 0.5*box[2])*prop)y2 = int((box[1] + 0.5*box[3])*prop)self.drawtext(src,(x1, y1), (x2, y2), self.classlist[class_id]+' '+str(score))if ifshow:dst = cv2.resize(src, (width//prop, height//prop))cv2.imshow('result', dst)cv2.waitKey(0)return srcdef video_identify(self, video_path:str) -> None:''' @func: 视频识别@para video_path: 视频路径@return: None'''cap = cv2.VideoCapture(video_path)fps = cap.get(cv2.CAP_PROP_FPS)# print(fps)while cap.isOpened():ret, frame = cap.read()#键盘输入空格暂停,输入q退出key = cv2.waitKey(1) & 0xffif key == ord(" "): cv2.waitKey(0)if key == ord("q"): breakif not ret: breakimg = self.img_identify(frame, False)cv2.imshow('result', img)# cv2.imshow('result', frame)if cv2.waitKey(int(10/fps)) == ord('q'):breakcap.release()cv2.destroyAllWindows()@staticmethoddef drawtext(image, pt1, pt2, text):''' @func: 根据给出的坐标和文本,在图片上进行绘制@para image: 图片数组; pt1: 左上角坐标; pt2: 右下角坐标; text: 矩形框上显示的文本,即类别信息@return: None'''fontFace = cv2.FONT_HERSHEY_COMPLEX_SMALL # 字体# fontFace = cv2.FONT_HERSHEY_COMPLEX # 字体fontScale = 1.5 # 字体大小line_thickness = 3 # 线条粗细font_thickness = 2 # 文字笔画粗细line_back_color = (0, 0, 255) # 线条和文字背景颜色:红色font_color = (255, 255, 255) # 文字颜色:白色# 绘制矩形框cv2.rectangle(image, pt1, pt2, color=line_back_color, thickness=line_thickness)# 计算文本的宽高: retval:文本的宽高; baseLine:基线与最低点之间的距离(本例未使用)retval, baseLine = cv2.getTextSize(text,fontFace=fontFace,fontScale=fontScale, thickness=font_thickness)# 计算覆盖文本的矩形框坐标topleft = (pt1[0], pt1[1] - retval[1]) # 基线与目标框上边缘重合(不考虑基线以下的部分)bottomright = (topleft[0] + retval[0], topleft[1] + retval[1])cv2.rectangle(image, topleft, bottomright, thickness=-1, color=line_back_color) # 绘制矩形框(填充)# 绘制文本cv2.putText(image, text, pt1, fontScale=fontScale,fontFace=fontFace, color=font_color, thickness=font_thickness)if __name__ == '__main__':clf = Onnx_clf()import tkinter as tkfrom tkinter.filedialog import askopenfilenameroot = tk.Tk()root.withdraw() # 隐藏主窗口source = askopenfilename(title="打开保存的图片或视频")if source.endswith('.jpg') or source.endswith('.png') or source.endswith('.bmp'):clf.img_identify(source)elif source.endswith('.mp4') or source.endswith('.avi'):print('视频识别中...按q退出')clf.video_identify(source)else:print('不支持的文件格式')