pytorch花式索引提取topk的张量

文章目录

  • pytorch花式索引提取topk的张量
    • 问题设定
    • 代码实现
      • 索引方法
      • gather方法
      • 验证
    • 补充知识
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的张量

问题设定

在这里插入图片描述
或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(bs, dim, X)的reduced向量。我们在进行topk操作(以减少计算量)的时候经常碰到这种情况。
给出如下两种实现方法,分别使用花式索引(参考informer的代码)以及pytorch的gather方法

代码实现

索引方法

参考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 结果和indices一致,说明在第二个channel上,每个样本的索引是一样的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在这里插入图片描述
在这里插入图片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

验证

两种方法得到的结果完全相同
在这里插入图片描述

补充知识

expand方法

在 PyTorch 中,expand() 方法用于扩展张量的大小。它会在不实际复制数据的情况下,重复张量的元素以填充新的形状。这个方法可以用于广播操作,以便在执行一些需要相同形状的张量之间的数学运算时,使它们具有相同的形状。

下面是使用 expand() 方法的基本用法:

import torch# 创建一个原始张量
x = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用 expand 扩展张量的大小
expanded_x = x.expand(2, 3, 4)  # 扩展成维度为(2, 3, 4)的张量print(expanded_x)

在上面的例子中,我们首先创建了一个形状为 (2, 3) 的原始张量 x。然后,我们使用 expand() 方法将其扩展成一个维度为 (2, 3, 4) 的新张量 expanded_x,该张量的形状是在原始张量形状的基础上每个维度都扩展了一倍。

需要注意的是,expand() 方法只能用于增加张量的大小,不能减小。另外,扩展后的张量与原始张量共享底层数据,因此在原始张量上进行的任何修改都会反映在扩展后的张量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于从输入张量中按照指定索引提取元素。这个方法通常用于根据索引收集特定的元素,例如根据类别索引从分类得分张量中获取对应类别的得分。

下面是使用 gather() 方法的基本用法:

import torch# 创建一个输入张量
input_tensor = torch.tensor([[1, 2],[3, 4],[5, 6]])# 创建一个索引张量
indices = torch.tensor([[0, 0],[1, 0]])# 使用 gather 方法根据索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)print(output_tensor)

在上面的例子中,我们首先创建了一个形状为 (3, 2) 的输入张量 input_tensor,以及一个形状为 (2, 2) 的索引张量 indices。然后,我们使用 gather() 方法从输入张量 input_tensor 中按照索引张量 indices 收集元素。

gather() 方法中,参数 dim 指定了在哪个维度上进行收集操作,而 index 参数指定了收集元素所使用的索引张量。

需要注意的是,索引张量 indices 的形状必须与输出张量的形状一致,或者是可以广播成与输出张量形状一致的形状。

randint

torch.randint() 是 PyTorch 中用于生成随机整数张量的函数。它可以生成一个张量,其中的元素是在指定范围内随机抽样的整数。

下面是 torch.randint() 的基本用法示例:

import torch# 生成一个形状为 (3, 3) 的随机整数张量,范围是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))print(random_integers)

在上面的示例中,我们使用了 torch.randint() 函数来生成一个形状为 (3, 3) 的随机整数张量,其中的元素取值范围在闭区间 [low, high) 内,即从 0 到 9。

torch.randint() 函数的主要参数包括:

  • low:生成的随机整数的最小值(包含)。
  • high:生成的随机整数的最大值(不包含)。
  • size:生成的张量的形状。

你也可以不指定 low 参数,默认情况下它为 0。此外,还可以使用其他参数来控制生成的随机整数张量的设备类型、数据类型等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/467493.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

APP inventor零基础移动应用开发

1.Android平台简介 Android由谷歌和开放手机联盟共同创建的一款针对手机的开源软件工具包 主要特色 ---开放性 – 丰富的硬件选择 – 开发商不受任何限制 – 无缝集成互联网服务 App Inventor是由Google公司开发的一款在线开放的Android编程工具软件,通过图形化…

涛哥聊Python | pymunk,一个强大的 Python 库!

本文来源公众号“涛哥聊Python”,仅用于学术分享,侵权删,干货满满。 原文链接:pymunk,一个强大的 Python 库! 大家好,今天为大家分享一个强大的 Python 库 - pymunk。 Github地址:…

STM32 + ESP8266,连接阿里云 上报/订阅数据

(文章正在编辑中,一点点地截图操作过程,估计要拖拉两三天) 一、烧录MQTT固件 ESP8266出厂时,默认是AT固件。连接阿里云,需要使用MQTT固件。 1、独立EPS8266模块的烧录方法 2、魔女开发板,板载…

项目排期 - 华为OD统一考试

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C 题目描述 项目组共有N个开发人员,项目经理接到了M个独立的需求,每个需求的工作量不同,且每个需求只能由一个开发人员独立完成&#xff0…

响应式编程三流处理

响应式编程三流处理 组合响应式流concatmergezipcombineLatest flatMap、concatMap、flatMapSequebtial操作符flatMapconcatMapflatMapSequential 元素采样sample 和sampleTimeout 流的批处理bufferwindow操作符group by将响应式流转化为阻塞结构在序列处理时查看元素物化和非物…

day 20(补2.5)

fread 函数: 今日练习 C语言面试题5道~ 1. static 有什么用途?(请至少说明两种) 1) 限制变量的作用域 2) 设置变量的存储域 2. 引用与指针有什么区别? 1) 引用必须被初始化,指针不必。 2) 引用初始…

uniapp微信小程序开发踩坑日记:Pinia持久化

如果你使用过Pinia,那你应该知道Pinia持久化插件:https://prazdevs.github.io/pinia-plugin-persistedstate/zh/ 但由于官方文档提供的说明并不是针对小程序开发,所以我们在使用这个插件实现uniapp小程序开发中Pinia持久化会出现问题 我在C…

【深度学习】S2 数学基础 P1 线性代数(上)

目录 基本数学对象标量与变量向量矩阵张量降维求和非降维求和累计求和 点积与向量积点积矩阵-向量积矩阵-矩阵乘法 深度学习的三大数学基础 —— 线性代数、微积分、概率论; 自本篇博文以下几遍博文,将对这三大数学基础进行重点提炼。 本节博文将介绍线…

Illegal escape character in string literal

问题 笔者进行Android项目开发,编译器提示报错 Illegal escape character in string literal详细问题 textView.setText(“A\B”); 解决方案 修改代码为A\B textView.setText(“A\B”) 产生原因 问题产生的原因是在字符串字面值中使用了非法的转义字符。在…

猫头虎分享已解决Bug ‍ || 错误SyntaxError: invalid syntax(Python)的解决方法

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

编写代码(LLVM的第一个项目)

下面这个完整代码 它相对较短,因为它建立在LLVM 流程的基础设施上 后者替我们完成大部分工作 我们从程序使用cl命名空间中的llvm工具(cl代表命令行)来实现我们的命令行接口 需要调用ParseCommandLineOption函数声明cl:&#xff…

滑动小短剧影视微信小程序源码/带支付收益等模式

仿抖音滑动小短剧影视微信小程序源码,带支付收益等模式、支持无限滑动;高性能滑动、预加载、视频预览,支持剧情介绍,集合壁纸另外仿抖音滑动效果;支持会员模式,支持用户单独购买等等多功能。 丰富的后台设…