4.grid_sample理解与使用

pytorch中的grid_sample

文章目录

  • pytorch中的grid_sample
    • grid_sample
    • `grid_sample`函数原型
    • 实例


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


grid_sample

直译为网格采样,给定一个mask patch,根据在目标图像上的坐标网格,将mask变换到目标图像上。

如上图,是将一个2x2mask根据坐标网格grid变换到6x6目标图像x0 y0 x1 y1 = 1,1,3,3的位置上,值得注意的是grid是经过运算得到的坐标网格,masktarget image对应位置的左上角处坐标应该为-1,-1,右下角处坐标应该为1,1,目标图像对应位置的像素值由mask通过插值得到。

知道了grid_sample的原理,再来看下torch中的函数。

grid_sample函数原型

torch.nn.functional.grid_sample(input,grid, mode='bilinear',                padding_mode='zeros', align_corners=None)
  • input输入image patch,支持4d5d输入。为4dshape N , C , H i n , W i n N,C,H_{in},W_{in} N,C,Hin,Win
  • grid坐标网格,当input4d时其shape N , H o u t , W o u t , 2 N,H_{out},W_{out},2 N,Hout,Wout,2,输出的shapeN,C,H_{out},W_{out},对于输出的位置output[n, :, h, w],‵grid[n, h, w]是二维向量,指定了其对应的input上的位置。output[n, :, h, w]根据‵grid[n, h, w]指定的对应input位置上的像素插值得到。grid指定了在input输入维度上标准化后的坐标大小,input左上角对应的应该是-1,-1,右下角对应的是1,1
  • mode插值方式,'bilinear' | 'nearest' | 'bicubic'
  • padding_mode,在(-1,1)外的输出图像上的像素值处理方式'zeros' | 'border' | 'reflection'
  • align_corners:是否对齐角

实例

以将一个100x100mask,网格采样到500x300的图像上(x,y,w,h)=(100, 100, 100, 200)为例,看一下grid_sample是如何使用的。

先计算grid,


import torch
import numpy as np
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plth, w = 300, 500
x0, y0, x1, y1 = torch.tensor([[100]]), torch.tensor([[100]]), torch.tensor([[200]]), torch.tensor([[300]])
N = 1
x0_int, y0_int = 0, 0
x1_int, y1_int = 500, 300
img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

这里使用的是mask在目标图像上的大小来对grid归一化的。

mask = np.zeros((100, 100), dtype=np.uint8)
ct = np.array([[50, 0],[99, 50], [50, 99], [0, 50]], dtype=np.int32)
mask = cv2.drawContours(mask, [ct], -1, 255,  cv2.FILLED)
plt.figure(1)
plt.imshow(mask)
mask = torch.from_numpy(mask)
masks = mask[None, None, :]if not torch.jit.is_scripting():if not masks.dtype.is_floating_point:masks = masks.float()img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
plt.figure(2)
plt.imshow(img_masks.squeeze().numpy().astype(np.uint8))

根据gridmask映射到目标图像上的指定区域指定大小。


1.https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/247066.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言 操作符详解

C语言学习 目录 文章目录 前言 一、算术操作符 二、移位操作符 2.1 左移操作符 2.2 右移操作符 三、位操作符 3.1 按位与操作符 & 3.2 按位或操作符 | 3.3 按位异或操作符 ^ 四、赋值操作符 五、单目操作符 5.1 逻辑反操作符! 5.2 正值、负值-操作符 5.3 取地址…

前端项目中获取浏览器版本的方法

在我们的前端项目中,navigator.userAgent属性含有当前浏览器相关信息(比如版本号)。 所以当我们想要获取用户当前访问的浏览器的版本时直接去解析navigator.userAgent字段就中。 废话不多说,下面看封装的获取浏览器版本的函数&am…

Learning Normal Dynamics in Videos with Meta Prototype Network 论文阅读

文章信息:发表在cvpr2021 原文链接: Learning Normal Dynamics in Videos with Meta Prototype Network 摘要1.介绍2.相关工作3.方法3.1. Dynamic Prototype Unit3.2. 视频异常检测的目标函数3.3. 少样本视频异常检测中的元学习 4.实验5.总结代码复现&a…

STM32串口接收不定长数据(空闲中断+DMA)

玩转 STM32 单片机,肯定离不开串口。串口使用一个称为串行通信协议的协议来管理数据传输,该协议在数据传输期间控制数据流,包括数据位数、波特率、校验位和停止位等。由于串口简单易用,在各种产品交互中都有广泛应用。 但在使用串…

基础堆溢出原理与DWORD SHOOT实现

堆介绍 堆的数据结构与管理策略 程序员在使用堆时只需要做三件事情:申请一定大小的内存,使用内存,释放内存。 对于堆管理系统来说,响应程序的内存使用申请就意味着要在"杂乱"的堆区中"辨别"出哪些内存是正在…

Python的文件的读写操作【侯小啾Python基础领航计划 系列(二十七)】

Python_文件的读写操作【侯小啾Python基础领航计划 系列(二十七)】 大家好,我是博主侯小啾, 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔…

实验案例二:多表查询

1、表联接类型。 表联接类型可以分为内联接.外联接和交叉联接等。 1.内联接。 内联接〈 inner join)是最常用的-一-种联接方式,只返回两个数据集合之间匹配关系的行,将位于两个互相交叉的数据集合中重叠部分以内的数…

“影响力”经济:抖音为什么更值得商家、达人长期深耕?

文|新熔财经 作者|叶一城 数亿的活跃用户,简单而自然的切入方式,快速、高频的执行效率,让抖音对电商界的冲击无可阻挡。 这背后,流量玩法登峰造极,是很多人的直接观感。 但实际上&#xff0…

原生横向滚动条 吸附 页面底部

效果图 /** 横向滚动条 吸附 页面底部 */ export class StickyHorizontalScrollBar {constructor(options {}) {const { el, style } optionsthis.createScrollbar(style)this.insertScrollbar(el)this.setScrollbarSize()this.onEvent()}/** 创建滚轴组件元素 */createS…

CCF CSP认证 历年题目自练Day51

此题又丑又长可以直接从题目分析(个人理解)部分看 题目 试题编号: 201812-3 试题名称: CIDR合并 时间限制: 1.0s 内存限制: 512.0MB 样例输入 2 1 2 样例输出 1.0.0.0/8 2.0.0.0/8 样例输入 2 10/9 10…

WEB渗透—反序列化(十一)

Web渗透—反序列化 课程学习分享(课程非本人制作,仅提供学习分享) 靶场下载地址:GitHub - mcc0624/php_ser_Class: php反序列化靶场课程,基于课程制作的靶场 课程地址:PHP反序列化漏洞学习_哔哩哔_…

【一周AI简讯】亚马逊推出企业级生成式AI聊天机器人,英伟达黄仁勋称AI将在5年内赶超人类

亚马逊推出企业级生成式AI聊天机器人Amazon Q 周二,亚马逊的云计算部门亚马逊网络服务 (AWS)推出了 Amazon Q,这是一款生成式 AI 聊天机器人。与 ChatGPT 和 Bard 不同,Amazon Q 并不基于单一的 AI 模型。相反,它在一个名为 Bedr…