【深度学习】张量的广播专题

一、说明

张量广播(tensor broadcasting)是一种将低维张量自动转化为高维张量的技术,使得张量之间可以进行基于元素的运算(如加、减、乘等)。在进行张量广播时,会将维度数较少的张量沿着长度为1的轴进行复制,在匹配维度后,两个张量就可以进行运算。

二、张量的基本概念

        当较小的张量被“拉伸”以具有与较大张量的兼容形状以执行操作时,就会发生广播。

 

广播可以成为执行张量运算而不创建重复数据的有效方法。

根据 PyTorch 的说法,在以下情况下,张量是“可广播的”:

每个张量至少有一个维度

循环访问维度大小时,从尾随维度开始,维度大小必须相等、其中一个为 1,或者其中一个不存在

比较形状时,尾随维度是最右边的数字。

在上图中,可以看到通用过程:

1. 确定最右侧的尺寸是否兼容

  • 每个张量是否至少有一个维度?
  • 大小相等吗?其中之一吗?一个不存在吗?

2. 将尺寸拉伸到适当的尺寸

3. 对下一个维度重复上述步骤

这些步骤可以在下面的示例中看到。

三、元素级操作

        所有元素级运算都要求张量具有相同的形状。

3.1 矢量和标量示例

import torch
a = torch.tensor([1, 2, 3])
b = 2 # becomes ([2, 2, 2])a * b
tensor([2, 4, 6])

        在此示例中,标量的形状为 (1,),矢量的形状为 (3,)。如图所示,b被广播为(3,)的形状,并且Hadamard乘积按预期执行。

3.2 矩阵和矢量示例 1

 

        在此示例中,A 的形状为 (3, 3),的形状为 (3,)。

发生乘法时,向量被逐行拉伸以创建一个矩阵,如上图所示。现在,A 和 b 的形状均为 (3, 3)。

        这可以在下面看到。


A = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])b = torch.tensor([1, 2, 3])A * b
tensor([[ 1,  4,  9],[ 4, 10, 18],[ 7, 16, 27]])

3.3 矩阵和矢量示例 2

 

        在此示例中,的形状为 (3, 3),的形状为 (3, 1)。

        发生乘法时,向量将逐列拉伸以创建两个额外的列,如上图所示。现在,A 和 b 的形状均为 (3, 3)。

A = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])b = torch.tensor([[1], [2], [3]])
A * b
tensor([[ 1,  2,  3],[ 8, 10, 12],[21, 24, 27]])

Tensor and Vector Example

         在此示例中,是形状为 (2, 3, 3) 的张量,是形状为 (3, 1) 的列向量。

A = (2, 3, 3)
b = ( , 3, 1)

        从最右边的维度开始,每个元素按列拉伸以生成 (3, 3) 矩阵。中间维度相等。在这一点上,b只是一个矩阵。最左侧的维度不存在,因此必须添加一个维度。然后,必须广播矩阵以创建 (2, 3, 3) 的大小。现在有两个 (3, 3) 个矩阵,可以在上图中看到。

        这允许计算 Hadamard 乘积并生成 (2, 3, 3) 矩阵:

A = torch.tensor([[[1, 2, 3],[4, 5, 6],[7, 8, 9]],[[1, 2, 3],[4, 5, 6],[7, 8, 9]]])b = torch.tensor([[1], [2], [3]])A * b
tensor([[[ 1,  2,  3],[ 8, 10, 12],[21, 24, 27]],[[ 1,  2,  3],[ 8, 10, 12],[21, 24, 27]]])

3.4 张量和矩阵示例

        在此示例中,是形状为 (2, 3, 3) 的张量,是形状为 (3, 3) 的矩阵。

A = (2, 3, 3)
B = ( , 3, 3)

        此示例比上一个示例更容易,因为最右侧的两个维度是相同的。这意味着矩阵只需在最左侧的维度上广播即可创建 (2, 3, 3) 的形状。这只是意味着需要一个额外的矩阵。

        计算哈达玛乘积时,结果为 (2, 3, 3)。

A = torch.tensor([[[1, 2, 3],[4, 5, 6],[7, 8, 9]],[[1, 2, 3],[4, 5, 6],[7, 8, 9]]])B = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])A * B
tensor([[[ 1,  4,  9],[ 4, 10, 18],[ 7, 16, 27]],[[ 1,  4,  9],[ 4, 10, 18],[ 7, 16, 27]]])

四、矩阵和张量乘法与点积

        对于前面的所有示例,目标是以相同的形状结束,以允许逐元素乘法。此示例的目标是通过点积实现矩阵和张量乘法,这需要第一个矩阵或张量的最后一个维度与第二个矩阵或张量的倒数第二个维度匹配。

        对于矩阵乘法:

  • (m, n) x (n, r) = (c, m, r)

        对于 3D 张量乘法:

  • (c, m, n) x (c, n, r) = (c, m, r)

对于 4D 张量乘法:

  • (z, c, m, n) x (z, c, n, r) = (z, c, m, r)

        对于此示例,A 的形状为 (2, 3, 3),的形状为 (3, 2)。截至目前,最后两个维度符合点积乘法的条件。需要将维度添加到 B,并且需要跨此维度广播 (3, 2) 矩阵以创建 (2, 3, 2) 的形状。

        此张量乘法的结果将是 (2, 3, 3) x (2, 3, 2) = (2, 3, 2)。

A = torch.tensor([[[1, 2, 3],[4, 5, 6],[7, 8, 9]],[[1, 2, 3],[4, 5, 6],[7, 8, 9]]])B = torch.tensor([[1, 2], [1, 2], [1, 2]])A @ B # A.matmul(B)
tensor([[[ 6, 12],[15, 30],[24, 48]],[[ 6, 12],[15, 30],[24, 48]]])

        有关广播的其他信息可以在下面的链接中找到。有关张量及其操作的更多信息可以在此处找到。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/27834.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】了解残差网 ResNet 和 ResNeXt 的架构

一、说明 了解和实现 ResNet 和 ResNeXt 的架构以实现最先进的图像分类:从Microsoft到 Facebook [第 1 部分],在这篇由两部分组成的博客文章中,我们将探讨残差网络。更具体地说,我们将讨论Microsoft研究和Facebook AI研究发布的三…

【C++初阶】list的模拟实现 附源码

一.list介绍 list底层是一个双向带头循环链表,这个我们以前用C语言模拟实现过,->双向带头循环链表 下面是list的文档介绍: list文档介绍 我们会根据 list 的文档来模拟实现 list 的增删查改及其它接口。 二.list模拟实现思路 既然是用C模拟…

【Vue】给 elementUI 中的 this.$confirm、this.$alert、 this.$prompt添加按钮的加载效果

文章目录 主要使用 beforeClose 方法实现 loading 的效果beforeClose MessageBox 关闭前的回调,会暂停实例的关闭 function(action, instance, done)1. action 的值为confirm, cancel或close。 2. instance 为 MessageBox 实例,可以通过它访问实例上的属…

Django + Bootstrap - 【echart】 统计图表进阶使用-统计用户日活日增、月活月增等数据(二)

一. 前言 Bootstrap是一个流行的前端框架,而ECharts是一个流行的可视化库。 Bootstrap可以用来设计网站和应用程序的用户界面,而ECharts可以用来创建交互式和可视化的图表。 chart.js中文文档:http://www.bootcss.com/p/chart.js/docs/ 二. …

LT8619B 是一款HDMI转TTL或者2 PORT LVDS的芯片。

LT8619B 1. 概述 LT8619B是龙迅基于清除边缘技术的高性能HDMI接收芯片,符合HDMI 1.4(高清多媒体接口)规范。RGB 输出端口可支持 RGB888/RGB666/RGB565 格式,输出分辨率最高可支持 4Kx2K 分辨率。凭借可编程标量,LT86…

切换.net Framework 版本后,出现NuGet 包是使用不同于当前目标框架的目标框架安装的,可能需要重新安装

问题现象: 由于添加新的dll文件,依赖的.NET Framework版本与当前的不一致,在vs 中切换了目标框架版本后,运行程序,出现以下的warnning信息: 一些 NuGet 包是使用不同于当前目标框架的目标框架安装的&#…

控制对文件访问

控制对文件访问 Linux文件权限 权限文件影响目录影响r读取文件内容列出目录内容w更改文件内容创建删除目录文件x作为命令执行目录可以变成当前工作目录 命令行管理文件系统权限 更改文件和目录权限 chmod chmod WhoWhatWhich file|directoryWho (u,g,o,a代表用户&#xff…

分布式事务 Seata

分布式事务 Seata 事务介绍分布式理论Seata 介绍Seata 部署与集成Seata TC Server 部署微服务集成 Seata XA 模式AT 模式AT 模式执行过程读写隔离写隔离读隔离 实现 AT 模式 TCC 模式TCC 模式介绍实现 TCC 模式 Saga 模式Seata 四种模式对比 事务介绍 事务(Transac…

使用GPU进行大规模并行仿真,解决强化学习采样瓶颈:CPU、GPU架构以及原理详解

强化学习的落地应用场景,我认为可以是仿真环境仿真程度高,且仿真速度快的任务场景。而这篇帖子将会将:使用 GPU 进行大规模并行仿真,解决强化学习采样瓶颈。并直接举出三个例子,展示如何对原有的仿真环境进行修改,让它们适应 GPU 并行加速。 1.强化学习论文背后的仿真环…

降级npm后,出现xxx 不是内部或外部命令解决方法

比如我安装了anyproxy npm install anyproxy -g 之后在cmd中输入anyproxy 发现 anyproxy 不是内部或外部命令解决方法. 一般出现这样的问题原因是npm安装出现了问题,全局模块目录没有被添加到系统环境变量。 Windows用户检查下npm的目录是否加入了系统变量P…

【JavaEE】Tomcat的安装和使用、创建Mevan项目使用Servlet写一个程序

目录 前言 一、Tomcat的下载和安装 二、写一个简单的Servlet项目 1、创建一个Maven项目 2、引入依赖 3、创建目录 4、编写Servlet代码。 5、打包程序 6、将程序部署到Tomcat上 7、验证程序运行结果 三、在IDEA上安装Smart Tomcat插件 四、Servlet中的一些常见错误 …

计算机网络 day9 DNAT实验

目录 DNAT DNAT策略的典型应用环境 DNAT策略的原理 在网关中使用DNAT策略发布内网服务器 DNAT实验: 实验环境: DNAT网络规划拓扑图: 步骤: 1、创建linux客户端Web网站(go语言),实现Web服…