pytorch异常——RuntimeError:Given groups=1, weight of size..., expected of...

文章目录

    • 省流
    • 异常报错
    • 异常截图
    • 异常代码
    • 原因解释
    • 修正代码
    • 执行结果

省流

  • nn.Conv2d 需要的输入张量格式为 (batch_size, channels, height, width),但您的示例输入张量 x 是 (batch_size, height, width, channels)。因此,需要对输入张量进行转置。

  • 注意,TensorFlow使用"NHWC"(批次、高度、宽度、通道)格式,而PyTorch使用"NCHW"(批次、通道、高度、宽度)格式

异常报错

RuntimeError: Given groups=1, weight of size [16, 3, 2, 3], 
expected input[8, 65, 66, 3] to have 3 channels, 
but got 65 channels instead

异常截图

在这里插入图片描述

异常代码

def down_shifted_conv2d(x , num_filters , filters_size = [2,3],stride = 1, **kwargs):batch_size,H,W,channels = x.shapepadding = (0,0,int(((filters_size[1]) - 1) / 2 ) , int((int(filters_size[1]) - 1) / 2),int(filters_size[0]) - 1 , 0,0,0)x_paded = nn.functional.pad(x, padding)print(x_paded.shape)conv_layer = nn.Conv2d(in_channels=channels, out_channels=num_filters, kernel_size=filters_size,stride=stride, **kwargs)return conv_layer(x_paded)
# Example usage
x = torch.randn(8, 64, 64, 3)  # Example input with batch size 8, height and width 64, and 3 channels
num_filters = 16
output = down_shifted_conv2d(x, num_filters)
print(output.shape)

原因解释

  • 在pytorch中,“nn.Conv2d”需要输入的张量格式为(batch_size,channels,height,width),原图输入的x的格式是(batch_size,height ,weight,channel)所以需要对tensor进行转置。

  • 矩阵交换维度的函数permute,按照编号,将新的顺序填好即可

def down_shifted_conv2d(x , num_filters , filters_size = [2,3], stride = 1, **kwargs):batch_size, H, W, channels = x.shape# Transpose the input tensor to (batch_size, channels, height, width)x = x.permute(0, 3, 1, 2)# Paddingpadding = (int((filters_size[1] - 1) / 2), int((filters_size[1] - 1) / 2),filters_size[0] - 1, 0)x_paded = F.pad(x, padding)

修正代码

def down_shifted_conv2d(x , num_filters , filters_size = [2,3],stride = 1, **kwargs):batch_size,H,W,channels = x.shape# 按照顺序对4个维度分别进行填充padding = (0,0,int(((filters_size[1]) - 1) / 2 ) , int((int(filters_size[1]) - 1) / 2),int(filters_size[0]) - 1 , 0,0,0)x_paded = nn.functional.pad(x, padding)x_paded = x_paded.permute(0,3,1,2)# 进行卷积conv_layer = nn.Conv2d(in_channels=channels, out_channels=num_filters, kernel_size=filters_size,stride=stride, **kwargs)return conv_layer(x_paded)
# Example usage
x = torch.randn(8, 64, 64, 3)  
num_filters = 16
output = down_shifted_conv2d(x, num_filters)
print(output.shape)

执行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/94763.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第 112 场 LeetCode 双周赛题解

A 判断通过操作能否让字符串相等 I s 1 s1 s1和 s 2 s2 s2第 1 1 1、 2 2 2位若同位置不等,则 s 1 s1 s1交换对应的 i i i和 j j j位置,之后判断 s 1 s1 s1和 s 2 s2 s2是否相当 class Solution { public:bool canBeEqual(string s1, string s2) {for (i…

全脑建模:过去、现在和未来

什么是全脑建模? 全脑建模(WBM)是计算神经科学的子领域,涉及近似全脑神经活动的计算和理论模型。该方法的目标是研究神经活动的宏观时空模式如何由解剖连接结构、内在神经动力学和外部扰动(感觉、认知、药理、电磁等)的相互作用产生。这种宏观现象及其模…

CSS3D+动画

CSS3D 1.css3D 给父元素设置 perspective:景深:近大远小的效果900-1200px这个范围内 transform-style:是否设置3D环境 flat 2D环境 默认值 perserve-3D环境 3D功能函数 1.位移: translateZ()translate3D(x,y,z) <!DOCTYPE html> <html lang"en"><h…

CDH6.3.2集成Kerberos

CDH6.3.2集成Kerberos 一.参考doc CDH enable kerberos: Kerberos Security Artifacts Overview | 6.3.x | Cloudera Documentation CDH disable kerberos:https://www.sameerahmad.net/blog/disable-kerberos-on-CDH; https://community.cloudera.com/t5/Support-Questions…

图像处理简介

目录 基本术语 1 .图像(image) 1.1 像素(Pixel) 1.2 颜色深度&#xff08;Color Depth&#xff09; 1.3 分辨率&#xff08;Resolution&#xff09; 1.4 像素宽高比&#xff08;Pixel Aspect Ratio&#xff09; 1.5 帧率(FPS) 1.6 码率&#xff08;BR&#xff09; 1. …

数据库(MySQL)的存储过程

一、存储过程介绍 存储过程是事先经过编译并存储在数据库中的一段SQL 语句的集合&#xff0c;调用存储过程可以简化应用开发人员的很多工作&#xff0c;减少数据在数据库和应用服务器之间的传输&#xff0c;对于提高数据处理的效率是有好处的。 存储过程思想上很简单&#xff0…

音频——I2S 右对齐模式(四)

I2S 基本概念飞利浦(I2S)标准模式左(MSB)对齐标准模式右(LSB)对齐标准模式DSP 模式TDM 模式 文章目录 I2S right时序图逻辑分析仪抓包 I2S right I2S 右对齐标准 也叫日本格式&#xff0c;sony 格式。相比于标准左对齐格式&#xff0c;标准右对齐的不足在于接收设备必须事先知…

53 个 CSS 特效 3(完)

53 个 CSS 特效 3&#xff08;完&#xff09; 前两篇地址&#xff1a; 53 个 CSS 特效 153 个 CSS 特效 2 这里是第 33 到 53 个&#xff0c;很多内容都挺重复的&#xff0c;所以这里解释没之前的细&#xff0c;如果漏了一些之前的笔记会补一下&#xff0c;写过的就会跳过。…

11 - 深入了解NIO的优化实现原理

Tomcat 中经常被提到的一个调优就是修改线程的 I/O 模型。Tomcat 8.5 版本之前&#xff0c;默认情况下使用的是 BIO 线程模型&#xff0c;如果在高负载、高并发的场景下&#xff0c;可以通过设置 NIO 线程模型&#xff0c;来提高系统的网络通信性能。 我们可以通过一个性能对比…

STM32+RTThread配置以太网无法ping通,无法获取动态ip的问题

记录一个非常蠢的问题&#xff0c;今天在移植rtthread的以太网驱动的时候出现无法获取动态ip的问题&#xff0c;问题如下&#xff1a; 设置为动态ip时不管是连接路由器还是电脑主机都无法ping通&#xff0c;也无法获取dns地址。 设置为静态ip时无法ping通主机。 使用wireshark…

python自动化测试-自动化基本技术原理

1 概述 在之前的文章里面提到过&#xff1a;做自动化的首要本领就是要会 透过现象看本质 &#xff0c;落实到实际的IT工作中就是 透过界面看数据。 掌握上面的这样的本领可不是容易的事情&#xff0c;必须要有扎实的计算机理论基础&#xff0c;才能看到深层次的本质东西。 …

YOLOv7框架解析

YOLOv7概念 YOLOv7是基于YOLO系列的目标检测算法&#xff0c;由Ultra-Light-Fast-Detection&#xff08;ULFD&#xff09;和Scaled-YOLOv4两种算法结合而来。它是一种高效、准确的目标检测算法&#xff0c;具有以下特点&#xff1a; 1. 高效&#xff1a;YOLOv7在保持准确率的…