Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

Pytorch中张量矩阵乘法函数使用说明

  • 1 torch.mm() 函数
    • 1.1 torch.mm() 函数定义及参数
    • 1.2 torch.bmm() 官方示例
  • 2 torch.bmm() 函数
    • 2.1 torch.bmm() 函数定义及参数
    • 2.2 torch.bmm() 官方示例
  • 3 torch.matmul() 函数
    • 3.1 torch.matmul() 函数定义及参数
    • 3.2 torch.matmul() 规则约定
    • 3.3 torch.matmul() 官方示例
    • 3.4 高维数据实例解释
  • 参考博文及感谢

1 torch.mm() 函数

全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维

1.1 torch.mm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的矩阵
** mat2
* (Tensor) – – 第二个要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

1.2 torch.bmm() 官方示例

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)tensor([[ 0.4851,  0.5037, -0.3633],[-0.0760, -3.6705,  2.4784]])

2 torch.bmm() 函数

全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

2.1 torch.bmm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一批要相乘的矩阵
** mat2
* (Tensor) – – 第二批要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

2.2 torch.bmm() 官方示例

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()torch.Size([10, 3, 5])

3 torch.matmul() 函数

可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。

3.1 torch.matmul() 函数定义及参数

torch.matmul(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的张量
** mat2
* (Tensor) – – 第二个要相乘的张量
支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

3.2 torch.matmul() 规则约定

(1)若两个都是1D(向量)的,则返回两个向量的点积;

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

(4)若input是2D,other是1D,则返回两者的点积结果;

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

  • (a)若input是1D,other是大于2D的,则类似于规则(3);
  • (b)若other是1D,input是大于2D的,则类似于规则(4);
  • (c)若input和other都是3D的,则与torch.bmm()函数功能一样;
  • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

3.3 torch.matmul() 官方示例

# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()torch.Size([10, 3, 5])

3.4 高维数据实例解释

直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算
在这里插入图片描述

代码如下:

import torch
import numpy as npnp.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(a)
print("=============================================")
print(b)
print("=============================================")
print(c.size())
print("=============================================")
print(c)

运行结果为:

tensor([[[[1, 0, 1, 0],[1, 1, 0, 1],[0, 0, 0, 0]],[[1, 1, 1, 1],[1, 1, 0, 0],[0, 1, 0, 1]]],[[[0, 0, 0, 1],[0, 0, 0, 1],[0, 1, 0, 0]],[[1, 1, 1, 1],[1, 1, 1, 1],[0, 0, 0, 0]]]], dtype=torch.int32)
=============================================
tensor([[[[0, 1, 0],[1, 1, 0],[0, 0, 0],[1, 1, 0]]],[[[0, 1, 0],[1, 1, 1],[1, 1, 1],[1, 0, 1]]]], dtype=torch.int32)
=============================================
torch.Size([2, 2, 3, 3])
=============================================
tensor([[[[0, 1, 0],[2, 3, 0],[0, 0, 0]],[[2, 3, 0],[1, 2, 0],[2, 2, 0]]],[[[1, 0, 1],[1, 0, 1],[1, 1, 1]],[[3, 3, 3],[3, 3, 3],[0, 0, 0]]]], dtype=torch.int32)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 官方文档查询地址
https://pytorch.org/docs/stable/index.html
参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
https://blog.csdn.net/irober/article/details/113686080

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/109004.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Element树形控件使用过程中遇到的问题及解决方法

1.需求1点击编辑按钮&#xff0c;出现修改组织弹窗&#xff0c;且将点击时的组织名称返现在输入框中。 思路是点击编辑按钮&#xff0c;取到节点点击时返回的data信息中的label进行赋值即可。 <el-treestyle"margin-top: 20px":data"organizationTreeData…

2023/9/14 -- C++/QT

作业&#xff1a; 仿照Vector实现MyVector&#xff0c;最主要实现二倍扩容 #include <iostream>using namespace std;template <typename T> class MyVector { private:T *data;size_t size;size_t V_capacity; public://无参构造MyVector():data(nullptr),size(…

【Java 基础篇】深入了解Java中的键值对集合:Map集合详解

Map是Java中常用的数据结构之一&#xff0c;用于存储键值对&#xff08;Key-Value&#xff09;映射。它提供了快速的查找和访问能力&#xff0c;是编程中常用的工具之一。本文将深入介绍Java中的Map集合&#xff0c;包括常见的Map实现类、基本操作、使用示例以及一些重要的注意…

MATLAB入门-矩阵的运算

MATLAB入门-矩阵的运算 本篇文章为学习笔记&#xff0c;课程链接为&#xff1a;头歌 相关知识 常见的矩阵运算有算术运算、关系运算和逻辑运算。MATLAB中的所有变量都是以矩阵的形式存储的&#xff0c;单个变量就相当于一个1*1的矩阵。 算术运算 下面展示的是常见的矩阵之…

【绝㊙️】三年开发内功心得

经典嵌套if-else问题 这个也是老生常谈问题了&#xff0c;不管哪里都能看到。 那如何解决 方法一&#xff08;重要&#xff09;&#xff1a; 如果逻辑分支过多&#xff0c; 即使你不解决嵌套if-slse&#xff0c;至少也要把每个 if的{}里的逻辑抽到一个独立的方法或者工具类…

Gin路由中间件详解

什么是中间件 Gin 中的中间件必须是一个 gin.HandlerFunc 类型,配置路由的时候可以传递多个 func 回调函 数, 最后一个 func 回调函数前面触发的方法 都可以称为中间件。 中间件操作演示 方法一: 直接写在func,回调函数内 r.GET("/middle",func(ctx *gin.Cont…

Tokenview X-ray功能:深入探索EVM系列浏览器的全新视角

Tokenview作为一家领先的多链区块浏览器&#xff0c;为了进一步优化区块链用户的使用体验&#xff0c;我们推出了X-ray&#xff08;余额透视&#xff09;功能。该功能将帮助您深入了解EVM系列浏览器上每个地址的交易过程&#xff0c;以一种直观、简洁的方式呈现地址的进出账情况…

C# 嵌套循环

例子说明 循环遍历xml文件中的信息包括&#xff1a;节点名称&#xff08;一个&#xff09;&#xff0c;节点的串联值&#xff08;一个&#xff09;&#xff0c;节点的属性&#xff08;多个&#xff09; Xml文件 <?xml version"1.0" encoding"utf-8" …

笔记(一)斯坦福CS224W图机器学习、图神经网络、知识图谱

节点和连接构成的图 如何对图数据进行挖掘&#xff1f; 传统机器学习&#xff0c;数据是独立同分布的&#xff0c;解决表格、矩阵、序列等问题 图机器学习处理连接的数据&#xff0c;需要满足以下几个方面&#xff1a; 1、图是任意尺寸输入2、图是动态变化的&#xff0c;有时…

腾讯mini项目-【指标监控服务重构】2023-07-27

今日已办 SigNoz Log Management SigNoz原生支持 OpenTelemetry 来收集日志&#xff0c;SigNoz 在收集器端进行了优化&#xff0c;为SigNoz中的日志添加了不同的功能。 OpenTelemetry 提供了各种接收器和处理器&#xff0c;用于直接通过 OpenTelemetry Collector 或通过 Flue…

里氏替换原则~

里氏替换原则&#xff08;Liskov Substitution Principle&#xff09;是面向对象设计中的一个基本原则&#xff0c;它是由Barbara Liskov提出的。 如果对于每一个类型为Apple的对象1&#xff0c;都有类型为fruit的对象2&#xff0c;使得以fruit定义的所有程序 P 在所有的对象1都…

基于Qt4开发曲线绘制交互软件Plotter

目前市面上有很多曲线绘制软件,但其交互功能较差。比如,想要实现数据的交互,同步联动等,都需要大量繁琐的人工操作。所以讲想开发一款轻量级的曲线绘制交互软件。下面就以此为案例,记录一下基于Qt4的开发过程。 目录 1 需求 2 技术路线 3 开发流程 1 框架搭建 2 菜单…