[PyTorch]即插即用的热力图生成

        先上张效果图,本来打算移植霹雳老师的使用Pytorch实现Grad-CAM并绘制热力图。但是看了下代码,需要骨干网络按照标准写法(即将特征层封装为features数组),而我写的网络图省事并没有进行封装,改造网络的代价又太大了,所以干脆直接重写一个。

一、生成热力图

        大致可以分为三步:①读取图片;②前向传递运算;③用特征向量生成特征图。而图片的resize图简单可以直接用transforms,后面反正也是直接resize回来的,并不会造成变形。

# 加载一个transforms用于变形,input_shape为预设的图像尺寸
transform = transforms.Compose([transforms.Resize((input_shape[0],input_shape[1])),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])
image = Image.open(image_path)     #image_path为文件路径
input_tensor = transform(image)    #将图片转换为tensor类型
input_batch = input_tensor.unsqueeze(0)    #为tensor添加batch维度# 前向传递
model.eval()
with torch.no_grad():output = model(input_batch)

        使用特征图生成热力图的原理是:将该维度上所有的tensor进行叠加,然后将生成的矩阵变形回输入向量的尺寸

heatmap = torch.sum(output, dim=1)    #所有通道求和
max_value = torch.max(heatmap)
min_value = torch.min(heatmap)
heatmap = (heatmap-min_value)/(max_value-min_value)*255heatmap = heatmap.cpu().numpy().astype(np.uint8).transpose(1,2,0)  # 提取热力图heatmap = cv2.resize(heatmap, input_shape,interpolation=cv2.INTER_LINEAR)  # 还原尺寸# 将矩阵转换为image类
heatmap=cv2.applyColorMap(heatmap,cv2.COLORMAP_JET)
heatimg = Image.fromarray(heatmap)

二、叠加原图

        直接使用plt进行叠加!

    # 将热力图叠加到原图上org_size = image.sizeheatimg = heatimg.resize(org_size)    #将热力图变回输入图像的尺寸plt.axis('off')plt.imshow(image)plt.imshow(heatimg, alpha=0.5)  # alpha为热力图的透明度# 显示叠加后的图形plt.show()

三、总结

        这段代码和霹雳老师的Grad-CAM对比优劣都很明显,优点是代码比较简单。上可以通过插入前向传递的环境直接得到任何层的热力图。但缺点就是不能关注特定的类别,且生成的热力图也不是很美观。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/140859.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UML类图关系(泛化 、继承、实现、依赖、关联、聚合、组合)

在UML类图中,常见的有以下几种关系: 泛化(Generalization), 实现(Realization),关联(Association),聚合(Aggregation),组合(Composition)&#x…

视频剪辑SDK,实现高效的移动端视频编辑

为了满足企业对视频编辑的需求,美摄提供了iOS/Android端视频编辑SDK技术开发服务,帮助企业快速高效地制作高质量视频。本文将详细介绍美摄的视频编辑SDK的优势和特点,以及如何为企业提供技术解决方案。 随着智能手机的普及和移动互联网的发展…

分享一个基于Python+Django的高校食堂外卖点餐系统的设计实现(源码、调试、开题、lw、ppt)

💕💕作者:计算机源码社 💕💕个人简介:本人七年开发经验,擅长Java、Python、PHP、.NET、微信小程序、爬虫、大数据等,大家有这一块的问题可以一起交流! 💕&…

ESP32集成开发环境Espressif-IDE安装 – Windows

陈拓 2023/10/15-2023/10/16 1. 概述 Espressif IDE是一个基于Eclipse CDT的集成开发环境(IDE),用于使用ESP-IDF框架开发物联网应用程序。这是一个专门为ESP-IDF构建的独立定制IDE。Espressif IDE附带了IDF Eclipse插件、重要的Eclipse CDT插…

深度学习 | CNN卷积核与通道

10.1、单通道卷积 以单通道卷积为例,输入为(1,5,5),分别表示1个通道,宽为5,高为5。 假设卷积核大小为3x3,padding0,stride1。 运算过程: 不断的在图像上进行遍历&#…

arrow(c++)改写empyrical系列1---用arrow读取基金净值数据并计算夏普率

用arrow c版本读取了csv中的基金净值数据,然后计算了夏普率,比较尴尬的是,arrow c版本计算耗费的时间却比python的empyrical版本耗费时间多。。。 arrow新手上路,第一次自己去实现功能,实现的大概率并不是最高效的方…

java击球小游戏运行代码

创建一个图形化的小游戏通常需要使用Java图形库,例如Swing或JavaFX。下面是一个使用JavaFX创建的简单的图形化小游戏示例,其中一个小球会在窗口内移动,你需要点击小球以增加得分: import javafx.application.Application; import…

Win10系统开机启动文件夹在哪里找?

Win10系统开机启动文件夹在哪里找?Win10系统开机启动文件夹是一个非常重要的目录,它决定了电脑在开机的时候,会有哪些应用程序是自动启动。但是,很多新手用户不知道Win10电脑内开机启动文件夹的具体位置,下面小编介绍开…

ms-sql server sql 把逗号分隔的字符串分开

案例: sql 查询-字段里是逗号,分隔开的数组,查询匹配数据 sql 查询-字段里是逗号,分隔开的数组,查询匹配数据_sql server 数组匹配-CSDN博客 SQL SERVER 把逗号隔开的字符串拆分成行 SQL SERVER 把逗号隔开的字符串拆分成行_sqlserver拆分…

索引背后的数据结构——B+树

为什么要使用B树? 可以进行数据查询的数据结构有二叉搜索树、哈希表等。对于前者来说,树的高度越高,进行查询比较的时候访问磁盘的次数就越多。而后者只有在数据等于key值的时候才能进行查询,不能进行模糊匹配。所以出现了B树来解…

Python+unittest接口自动化测试

首先配置好开发环境,下载安装Python并下载安装pycharm,在pycharm中创建项目功能目录。以下是项目的目录结构: common: 1 2 3 4 5 6 7 8 9 ——configDb.py:这个文件主要编写数据库连接池的相关内容,本项目…

【C++11】右值引用、移动构造、移动赋值、完美转发 的原理介绍

文章目录 一、概念1.1 左值1.2 左值引用1.3 什么是右值?1.4 什么是右值引用?对于参数左值还是右值的不同,是被重载支持的左值引用的使用场景 和 缺陷 二、移动语义2.1 移动拷贝构造2.2 移动赋值 三、右值引用 与 STL3.1 移动拷贝构造 和 赋值…