Variations-of-SFANet-for-Crowd-Counting可视化代码

前文对Variations-of-SFANet-for-Crowd-Counting做了一点基础梳理,链接如下:Variations-of-SFANet-for-Crowd-Counting记录-CSDN博客

本次对其中两个可视化代码进行梳理

1.Visualization_ShanghaiTech.ipynb

不太习惯用jupyter notebook, 这里改成了python代码测试,下面代码提到的测试数据都是项目自带的,权重自己下载一下吧,前文提到了一些需要下载的权重或者数据。

import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from matplotlib import cm as CMimport os
import numpy as np
from scipy.io import loadmat
from PIL import Image; import cv2
import torch
from torchvision import transforms
from models import M_SFANet
part = 'B'; index = 4
DATA_PATH = f"./ShanghaiTech_Crowd_Counting_Dataset/part_{part}_final/test_data/"
fname = os.path.join(DATA_PATH, "ground_truth", f"GT_IMG_{index}.mat")
img = Image.open(os.path.join(DATA_PATH, "images", f"IMG_{index}.jpg")).convert('RGB')
plt.imshow(img)
plt.gca().set_axis_off()
plt.show()
gt = loadmat(fname)["image_info"]
location = gt[0, 0][0, 0][0]
count = location.shape[0]
print(fname)
print('label:', count)
model = M_SFANet.Model()
model.load_state_dict(torch.load(f"./ShanghaitechWeights/checkpoint_best_MSFANet_{part}.pth", map_location=torch.device('cpu'))["model"]);
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])height, width = img.size[1], img.size[0]
height = round(height / 16) * 16
width = round(width / 16) * 16
img = cv2.resize(np.array(img), (width,height), Image.BILINEAR)
img = trans(Image.fromarray(img))[None, :]
model.eval()
density_map, attention_map = model(img)
print('Estimated count:', torch.sum(density_map).item())
print("Visualize estimated density map")
plt.gca().set_axis_off()
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(density_map[0][0].detach().numpy(), cmap = CM.jet)
# plt.savefig(fname=..., dpi=300)
plt.show()

运行结果如下,还有两张可视化的图

上面这样看是不是不太直观,下面这张图够直观

2.Visualization_UCF-QNRF.ipynb

同上改成了python代码测试

import torch
import os
import numpy as np
from datasets.crowd import Crowd
from models.vgg import vgg19
import argparse
from PIL import Image
import cv2
import sys
# sys.path.insert(0, '/home/pongpisit/CSRNet_keras/')
from models import M_SegNet_UCF_QNRF
from matplotlib import pyplot as plt
from matplotlib import cm as CM
datasets = Crowd(os.path.join('/home/pongpisit/CSRNet_keras/CSRNet-keras/wnet_playground/W-Net-Keras/data/UCF-QNRF_ECCV18/processed/', 'test'), 512, 8, is_gray=False, method='val')
dataloader = torch.utils.data.DataLoader(datasets, 1, shuffle=False,num_workers=8, pin_memory=False)
model = M_SegNet_UCF_QNRF.Model()
device = torch.device('cuda')
model.to(device)
# model.load_state_dict(torch.load(os.path.join('./u_logs/0331-111426/', 'best_model.pth'), device))
model.load_state_dict(torch.load(os.path.join('./seg_logs/0327-172121/', 'best_model.pth'), device))
model.eval()epoch_minus = []
preds = []
gts = []for inputs, count, name in dataloader:inputs = inputs.to(device)assert inputs.size(0) == 1, 'the batch size should equal to 1'with torch.set_grad_enabled(False):outputs = model(inputs)temp_minu = count[0].item() - (torch.sum(outputs).item())preds.append(torch.sum(outputs).item())gts.append(count[0].item())print(name, temp_minu, count[0].item(), torch.sum(outputs).item())epoch_minus.append(temp_minu)epoch_minus = np.array(epoch_minus)
mse = np.sqrt(np.mean(np.square(epoch_minus)))
mae = np.mean(np.abs(epoch_minus))
log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
print(log_str)
met = []
for i in range(len(preds)):met.append(100 * np.abs(preds[i] - gts[i]) / gts[i])idxs = []
for i in range(len(met)):idxs.append(np.argmin(met))if len(idxs) == 5: breakmet[np.argmin(met)] += 100000000
print(set(idxs))
def resize(density_map, image):density_map = 255*density_map/np.max(density_map)density_map= density_map[0][0]image= image[0]print(density_map.shape)result_img = np.zeros((density_map.shape[0]*2, density_map.shape[1]*2))for i in range(result_img.shape[0]):for j in range(result_img.shape[1]):result_img[i][j] = density_map[int(i / 2)][int(j / 2)] / 4result_img  = result_img.astype(np.uint8, copy=False)return result_imgdef vis_densitymap(o, den, cc, img_path):fig=plt.figure()columns = 2rows = 1
#     X = np.transpose(o, (1, 2, 0))X = osumm = int(np.sum(den))den = resize(den, o)for i in range(1, columns*rows +1):# image plotif i == 1:img = Xfig.add_subplot(rows, columns, i)plt.gca().set_axis_off()plt.margins(0,0)plt.gca().xaxis.set_major_locator(plt.NullLocator())plt.gca().yaxis.set_major_locator(plt.NullLocator())plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)plt.imshow(img)# Density plotif i == 2:img = denfig.add_subplot(rows, columns, i)plt.gca().set_axis_off()plt.margins(0,0)plt.gca().xaxis.set_major_locator(plt.NullLocator())plt.gca().yaxis.set_major_locator(plt.NullLocator())plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)plt.text(1, 80, 'M-SegNet* Est: '+str(summ)+', Gt:'+str(cc), fontsize=7, weight="bold", color = 'w')plt.imshow(img, cmap=CM.jet)filename = img_path.split('/')[-1]filename = filename.replace('.jpg', '_heatpmap.png')print('Save at', filename)plt.savefig('seg_'+filename, transparent=True, bbox_inches='tight', pad_inches=0.0, dpi=200)processed_dir = '/home/pongpisit/CSRNet_keras/CSRNet-keras/wnet_playground/W-Net-Keras/data/UCF-QNRF_ECCV18/processed/test/'model.eval()c = 0for inputs, count, name in dataloader:img_path = os.path.join(processed_dir, name[0]) + '.jpg'if c in set(idxs):inputs = inputs.to(device)with torch.set_grad_enabled(False):outputs = model(inputs)img = Image.open(img_path).convert('RGB')height, width = img.size[1], img.size[0]height = round(height / 16) * 16width = round(width / 16) * 16img = cv2.resize(np.array(img), (width,height), cv2.INTER_CUBIC)print('Do VIS')vis_densitymap(img, outputs.cpu().detach().numpy(), int(count.item()), img_path)c += 1        else:c += 1

但是该代码要用UCF-QNRF_ECCV18数据集,官网的太慢了,给个靠谱的链接:UCF-QNRF_数据集-阿里云天池

下载下来,然后利用bayesian_preprocess_sh.py这个代码处理一下就可以用于上述代码了,注意一下UCF-QNRF_ECCV18的mat文件中点坐标的读取代码有点问题,自己输出一下mat文件信息就看得出来了。输出文件夹中会有相应的jpg和npy文件。

运行可视化代码,这期间遇到了一个报错

ImportError: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (C:\Anaconda3\lib\site-packages\charset_normalizer\constant.py)

邪门解决方案,安装一个chardet

pip install chardet -i https://pypi.tuna.tsinghua.edu.cn/simple

要是上述方法还不好使就换一个,更新一下charset_normalizer,或者卸载重装charset_normalizer

pip install --upgrade charset-normalizer

要是出现如下报错

RuntimeError:An attempt has been made to start a new process before thecurrent process has finished its bootstrapping phase.This probably means that you are not using fork to start yourchild processes and you have forgotten to use the proper idiomin the main module:if __name__ == '__main__':freeze_support()...The "freeze_support()" line can be omitted if the programis not going to be frozen to produce an executable.

把代码中的num_workers改成0,跑起来结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/156954.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人工智能AI 全栈体系(十一)

第一章 神经网络是如何实现的 这些神经网络越来越复杂,都是用BP算法求解。网络有些变化就可能需要重新推导,而在实验过程中可能会做很多尝试,这样每次都重新推导BP算法太麻烦了。 十、深度学习框架 现在有了很多深度学习框架,这…

AQS面试题总结

一:线程等待唤醒的实现方法 方式一:使用Object中的wait()方法让线程等待,使用Object中的notify()方法唤醒线程 必须都在synchronized同步代码块内使用,调用wait,notify是锁定的对象; notify必须在wait后执…

QT5交叉编译保姆级教程(arm64、mips64)

什么是交叉编译? 简单说,就是在当前系统平台上,开发编译运行于其它平台的程序。 比如本文硬件环境是x86平台,但是编译出来的程序是在arm64架构、mips64等架构上运行 本文使用的操作系统:统信UOS家庭版22.0 一、安装…

由QTableView/QTableWidget显示进度条和按钮,理解qt代理delegate用法

背景: 我的最初应用场景,就是要在表格上用进度条显示数据,以及放一个按钮。 qt-creator中有自带的delegate示例可以参考,但终归自己动手还是需要理解细节,否则不能随心所欲。 自认没那个天赋,于是记录下…

8. 一文快速学懂常用工具——Linux命令(上)

本章讲解知识点 引言 指令学习 本专栏适合于软件开发刚入职的学生或人士,有一定的编程基础,帮助大家快速掌握工作中必会的工具和指令。本专栏针对面试题答案进行了优化,尽量做到好记、言简意赅。如专栏内容有错漏,欢迎在评论区指…

瑞数专题五

今日文案:焦虑,想象力过度发酵的产物。 网址:https://www.iyiou.com/ 专题五主要是分享瑞数6代。6代很少见,所以找理想哥要的,感谢感谢。 关于瑞数作者之前已经分享过4篇文章,全都收录在瑞数专栏中了&am…

【计算机网络】浏览器的通信能力

1. 用户代理 浏览器可以代替用户完成http请求,代替用户解析响应结果,所以我们称之为用户代理 user agent。 浏览器两大核心能力: 自动发送请求的能力自动解析响应的能力 1.1 自动发送请求的能力 用户在地址栏输入了一个url地址&#xff0…

HarmonyOS数据管理与应用数据持久化(一)

一. 数据管理概述 功能介绍 数据管理为开发者提供数据存储、数据管理能力,比如联系人应用数据可以保存到数据库中,提供数据库的安全、可靠等管理机制。 数据存储:提供通用数据持久化能力,根据数据特点,分为用户首选项、…

逻辑分析仪与示波器选择

一、简介 逻辑分析仪是利用时钟从测试设备上采集和显示数字信号的仪器,最主要的作用在于时序判定。逻辑分析仪与示波器不同,它不能显示连续的模拟量波形,而只显示高低两种电平状态(逻辑1和0)。在设置了参考电压后&…

css中flexbox和grid的区别

css中flexbox和grid的区别 我们是不是被那些不会按预期排列的元素所影响?这篇文章我们将深入探讨css中flexbox和grid的布局。通过了解他们的主要差异,我们会发现这些布局是如何改变我们网站的风格。 理解CSS布局 css布局是网页设计的一个重要方面&…

【广州华锐互动】VR特警作战模拟演练系统

在科技发展的驱动下,各行各业都在寻找新的方式来提升效率和培训质量。其中,虚拟现实(VR)技术在各个领域都有广泛的应用,包括警察培训。VR特警作战模拟演练系统由VR公司广州华锐互动开发,它使用虚拟现实环境…

高效文件整理:按数量划分自动建立文件夹,轻松管理海量文件

在日常生活和工作中,我们经常需要处理大量的文件。然而,如何高效地整理这些文件却是一个棘手的问题。有时候,我们可能需要按照特定的规则来建立文件夹,以便更高效地整理文件。例如,您可以按照日期、时间或者特定的标签…