TVRNet网络PyTorch实现

文章目录

    • 文章地址
    • 网络各层结构
    • 代码实现

文章地址

  • An End-to-End Traffic Visibility Regression Algorithm
  • 文章通过训练搜集得到的真实道路图像数据集(Actual Road dense image Dataset, ARD),通过专业的能见度计和多人标注,获得可靠的能见度标签数据集。构建网络,进行训练,获得了较好的能见度识别网络。网络包括特征提取​、多尺度映射​、特征融合​、非线性输出(回归范围为[0,1],需要经过(0,0),(1,1)改用修改的sigmoid函数,相较于ReLU更好)。结构如下​
    在这里插入图片描述

网络各层结构

在这里插入图片描述

  • 我认为红框位置与之相应的参数不匹配,在Feature Extraction部分Reshape之后得到的特征图大小为4124124。紧接着接了一个卷积层Conv,显示输入是3128128
  • 第二处红框,MaxPool的kernel设置为88,特征图没有进行padding,到全连接层的输入变为64117*117,参数不对应
    在这里插入图片描述

代码实现

"""Based on the ideas of the below paper, using PyTorch to build TVRNet.Reference: Qin H, Qin H. An end-to-end traffic visibility regression algorithm[J]. IEEE Access, 2021, 10: 25448-25454.​@weishuo
"""import torch
from torch import nn
import mathclass Inception(nn.Module):def __init__(self, in_planes, out_planes):super(Inception, self).__init__()self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)self.conv5 = nn.Conv2d(in_planes, out_planes, kernel_size=5, padding=2)self.conv7 = nn.Conv2d(in_planes, out_planes, kernel_size=7, padding=3)def forward(self, x):out_1 = self.conv1(x)out_3 = self.conv3(x)out_5 = self.conv5(x)out_7 = self.conv7(x)out = torch.cat((out_1, out_3, out_5, out_7), dim=1)return outdef modify_sigmoid(x):return 1 / (1 + torch.exp(-10*(x-0.5)))class TVRNet(nn.Module):def __init__(self, in_planes, out_planes):super(TVRNet, self).__init__()# (B, 3, 224, 224)  ——>  (B, 3, 220, 220)self.FeatureExtraction_onestep = nn.Sequential(nn.Conv2d(in_planes, 20, kernel_size=5, padding=0),nn.ReLU(inplace=True),)self.FeatureExtraction_maxpool = nn.MaxPool2d((5, 1))self.MultiScaleMapping = nn.Sequential(Inception(4, 16),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=8))self.FeatureIntegration = nn.Sequential(nn.Linear(46656, 100),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(100, out_planes))self.NonLinearRegression = modify_sigmoiddef forward(self, x):x = self.FeatureExtraction_onestep(x)x = x.view((x.shape[0], 1, x.shape[1], -1))x = self.FeatureExtraction_maxpool(x)x = x.view(x.shape[0], x.shape[2], int(math.sqrt(x.shape[3])), int(math.sqrt(x.shape[3])))# print(x.shape)x = self.MultiScaleMapping(x)# print(x.shape)x = x.view(x.shape[0], -1)x = self.FeatureIntegration(x)out = self.NonLinearRegression(x)return outif __name__ == '__main__':a = torch.randn(1,3,224,224)net = TVRNet(3,3)b = net(a)print(b.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/151980.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javascript数据类型

目录 原始数据类型 引用数据类型 类型检测 类型转换 总结 原始数据类型 JavaScript 中有六种原始数据类型,它们是: Undefined(未定义): 表示一个未被赋值的变量。Null(空值): 表示一个空对象指针。B…

Qt 实现侧边栏滑出菜单效果

1.效果图 2.实现原理 这里做了两个widget,一个是 展示底图widget,一个是 展示动画widget。 这两个widget需要重合。动画widget需要设置属性叠加到底图widget上面,设置如下属性: setWindowFlags(Qt::FramelessWindowHint | Qt::…

【优选算法系列】第二节.双指针(202. 快乐数和11. 盛最多水的容器)

作者简介:大家好,我是未央; 博客首页:未央.303 系列专栏:优选算法系列 每日一句:人的一生,可以有所作为的时机只有一次,那就是现在!!!&#xff01…

基于java+springboot的人事招聘信息网站

运行环境 开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven 项目介绍 开发过程…

锐捷EG易网关 phpinfo.view.php 信息泄露

致未经身份验证获取敏感信息 访问漏洞url: /tool/view/phpinfo.view.php漏洞证明: 文笔生疏,措辞浅薄,望各位大佬不吝赐教,万分感谢。 免责声明:由于传播或利用此文所提供的信息、技术或方法而造成的任何…

【强化学习】10 —— DQN算法

文章目录 深度强化学习价值和策略近似RL与DL结合产生的问题深度强化学习的分类 Q-learning回顾深度Q网络(DQN)经验回放优先经验回放 目标网络算法流程 代码实践CartPole环境代码结果 参考 深度强化学习 价值和策略近似 我们可以利用深度神经网络建立这些…

设计模式(19)命令模式

一、介绍: 1、定义:命令模式(Command Pattern)是一种行为设计模式,它将请求封装为一个对象,从而使你可以使用不同的请求对客户端进行参数化。命令模式还支持请求的排队、记录日志、撤销操作等功能。 2、组…

Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能 前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。 实现代码 import tensorflow as tf from sklearn.datasets import…

哈希算法:如何防止数据库中的用户信息被脱库?

文章来源于极客时间前google工程师−王争专栏。 2011年CSDN“脱库”事件,CSDN网站被黑客攻击,超过600万用户的注册邮箱和密码明文被泄露,很多网友对CSDN明文保存用户密码行为产生了不满。如果你是CSDN的一名工程师,你会如何存储用…

debian 10 安装apache2 zabbix

nginx 可以略过,改为apache2 apt updateapt-get install nginx -ynginx -v nginx version: nginx/1.14.2mysql 安装参考linux debian10 安装mysql5.7_debian apt install mysql5.7-CSDN博客 Install and configure Zabbix for your platform a. Install Zabbix re…

SpringCore完整学习教程5,入门级别

本章从第6章开始 6. JSON Spring Boot提供了三个JSON映射库的集成: Gson Jackson JSON-B Jackson是首选的和默认的库。 6.1. Jackson 为Jackson提供了自动配置,Jackson是spring-boot-starter-json的一部分。当Jackson在类路径上时,将自动配置Obj…

uniapp 中添加 vconsole

uniapp 中添加 vconsole 一、安装 vconsole npm i vconsole二、使用 vconsole 在项目的 main.js 文件中添加如下内容 // #ifdef H5 // 提交前需要注释 本地调试使用 import * as vconsole from "vconsole"; new vconsole() // 使用 vconsole // #endif三、成功