PyTorch常用工具(2)预训练模型

文章目录

  • 前言
  • 2 预训练模型

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

2 预训练模型

除了加载数据,并对数据进行预处理之外,torchvision还提供了深度学习中各种经典的网络结构以及预训练模型。这些模型封装在torchvision.models中,包括经典的分类模型:VGG、ResNet、DenseNet及MobileNet等,语义分割模型:FCN及DeepLabV3等,目标检测模型:Faster RCNN以及实例分割模型:Mask RCNN等。读者可以通过下述代码使用这些已经封装好的网络结构与模型,也可以在此基础上根据需求对网络结构进行修改:

from torchvision import models
# 仅使用网络结构,参数权重随机初始化
mobilenet_v2 = models.mobilenet_v2()
# 加载预训练权重
deeplab = models.segmentation.deeplabv3_resnet50(pretrained=True)

下面使用torchvision中预训练好的实例分割模型Mask RCNN进行一次简单的实例分割:

In: from torchvision import modelsfrom torchvision import transforms as Tfrom torch import nnfrom PIL import Imageimport numpy as npimport randomimport cv2# 加载预训练好的模型,不存在的话会自动下载# 预训练好的模型保存在 ~/.torch/models/下面detection = models.detection.maskrcnn_resnet50_fpn(pretrained=True)detection.eval()def predict(img_path, threshold):# 数据预处理,标准化至[-1, 1],规定均值和标准差img = Image.open(img_path)transform = T.Compose([T.ToTensor(),T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])])img = transform(img)# 对图像进行预测pred = detection([img])# 对预测结果进行后处理:得到mask与bboxscore = list(pred[0]['scores'].detach().numpy())t = [score.index(x) for x in score if x > threshold][-1]mask = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()pred_boxes = [[(i[0], i[1]), (i[2], i[3])] \for i in list(pred[0]['boxes'].detach().numpy())]pred_masks = mask[:t+1]boxes = pred_boxes[:t+1]return pred_masks, boxes

Transforms中涵盖了大部分对Tensor和PIL Image的常用处理,这些已在上文提到,本节不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如transf = transforms.Normalize(mean=x, std=y);第二步:执行转换操作,例如output = transf(input)。另外还可以将多个处理操作用Compose拼接起来,构成一个处理转换流程。

In: # 随机颜色,以便可视化def color(image):colours = [[0, 255, 255], [0, 0, 255], [255, 0, 0]]R = np.zeros_like(image).astype(np.uint8)G = np.zeros_like(image).astype(np.uint8)B = np.zeros_like(image).astype(np.uint8)R[image==1], G[image==1], B[image==1] = colours[random.randrange(0,3)]color_mask = np.stack([R,G,B],axis=2)return color_mask
In: # 对mask与bounding box进行可视化def result(img_path, threshold=0.9, rect_th=1, text_size=1, text_th=2):masks, boxes = predict(img_path, threshold)img = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)for i in range(len(masks)):color_mask = color(masks[i])img = cv2.addWeighted(img, 1, color_mask, 0.5, 0)cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(255,0,0), thickness=rect_th)return img
In: from matplotlib import pyplot as pltimg=result('data/demo.jpg')plt.figure(figsize=(10, 10))plt.axis('off')img_result = plt.imshow(img)

TensorBoard界面

上述代码完成了一个简单的实例分割任务。如上图所示,Mask RCNN能够分割出该图像中的部分实例,读者可考虑对预训练模型进行微调,以适应不同场景下的不同任务。注意:上述代码均在CPU上进行,速度较慢,读者可以考虑将数据与模型转移至GPU上,具体操作可以参考第4节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/312237.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8改进 | 检测头篇 | ASFF改进YOLOv8检测头(全网首发)

一、本文介绍 本文给大家带来的改进机制是利用ASFF改进YOLOv8的检测头形成新的检测头Detect_ASFF,其主要创新是引入了一种自适应的空间特征融合方式,有效地过滤掉冲突信息,从而增强了尺度不变性。经过我的实验验证,修改后的检测头…

影视后期: PR调色处理,调色工具面板介绍

写在前面 整理一些影视后期的相关笔记博文为 Pr 调色处理,涉及调色工具面板简单认知包括 lumetri 颜色和范围面板理解不足小伙伴帮忙指正 元旦快乐哦 _ 名词解释 饱和度 是指色彩的鲜艳程度,也被称为色彩的纯度。具体来说,它表示色相中灰色…

Python装饰器的专业解释

装饰器,其实是用到了闭包的原理来进行操作的。 单个装饰器: 以下是一个简单的例子: def outer(func):print("OUTER enter ...")def wrapper(*args, **kwargs):print("调用之前......")result func(*args, **kwargs)p…

Excel模板填充:从minio上获取模板使用easyExcel填充

最近工作中有个excel导出的功能,要求导出的模板和客户提供的模板一致,而客户提供的模板有着复杂的表头和独特列表风格,像以往使用poi去画是非常耗时间的,比如需要考虑字体大小,单元格合并,单元格的格式等问…

微服务篇之Nacos快速入门

Nacos 简介 Nacos 起源 Nacos 起源于阿里巴巴 2008 年的五彩石项目(完成微服务拆分和业务中台建设),经历了阿里十年双十⼀的洪峰流量的考验,沉淀了简单易用、稳定可靠、性能卓越等核心特性。随着云计算的兴起和受到开源软件行业…

2021-05-08 51单片机74HC164、74LS164、74HCT164、74HC154、74HCT154应用三极管控制继电器

74HC164、74HCT164是8位边沿触发式移位寄存器,串行输入数据,然后并行输出。数据通过两个输入端(DSA或DSB)之一串行输入;任一输入端可以用作高电平使能端,控制另一输入端的数据输入。两个输入端或者连接在一…

文献综述 AI 应用对比 — Elicit, GPTs 与 Perplexity

(注:本文为小报童精选文章,已订阅小报童或加入知识星球「玉树芝兰」用户请勿重复付费) 通过我的这些尝试,你无需再自己去摸索,可以直接根据我展示的结果选择合适的工具,更有效地进行文献回顾。 …

SpringMVC框架

SpringMVC 三层架构MVC模式SpringMVC入门案例总结 三层架构 表现层(web) 页面数据的收集,产出页面 业务逻辑层(service) 业务处理 数据访问层(Dao) 数据持久化 MVC模式 SpringMVC 基于Java…

WorkPlus私有化即时通讯的标杆,助力企业实现信息管控与保障

在信息时代,保护企业的信息安全至关重要。而私有化即时通讯成为了企业提升信息安全的重要手段。作为私有化即时通讯的领先选择,WorkPlus以其卓越的性能和领先的技术,为企业提供了安全可靠的通信解决方案。 私有化即时通讯是企业保护信息安全的…

数据结构:堆的三部曲 (一)堆的实现

堆的实现 1.堆的结构1.1堆的定义理解 2.堆的实现(以小根堆为例)2.1 堆结构体的定义2.2 堆的插入交换函数向上调整算法插入函数的代码 2.3 堆的删除向下调整算法:删除函数的代码: 2.4其他操作 3.测试以及完整源代码实现3.1测试代码…

redis—List列表

目录 前言 1.常见命令 2.使用场景 前言 列表类型是用来存储多个有序的字符串,如图2-19所示,a、b、C、d、e五个元素从左到右组成 了一个有序的列表,列表中的每个字符串称为元素(element) ,一个列表最多可以存储2^32 - 1 个元素…

Spring Cloud Gateway集成Knife4j

1、前提 网关路由能够正常工作。 案例 基于 Spring Cloud Gateway Nacos 实现动态路由拓展的参考地址:Spring Cloud Gateway Nacos 实现动态路由 详细官网案例:https://doc.xiaominfo.com/docs/middleware-sources/spring-cloud-gateway/spring-gatewa…