Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

我们知道torch.meshgrid()函数的功能是生成网格,可以用于生成坐标;

在numpy中也有一样的函数np.meshgrid(),但是用法不太一样,我们直接上代码进行解释。

1、两者在用法上的区别

比如:我要生成下图的xy坐标点,看下两者的实现方式:

在这里插入图片描述

np.meshgrid()

>>> import numpy as np
>>> w, h = 4, 2
# 注意,此时输入的是由w和h生成的一维数组
#      此时输出的是网格x的坐标grid_x以及网格y的坐标grid_y
>>> grid_x, grid_y  = np.meshgrid(np.arange(w), np.arange(h)) >>> grid_x
array([[0, 1, 2, 3],  [0, 1, 2, 3]])
>>> grid_y
array([[0, 0, 0, 0],[1, 1, 1, 1]])

torch.meshgrid()

>>> import torch
# 注意,此时输入的是由h和w生成的一维数组(和numpy中的输入顺序相反)
#      此时输出的是网格y的坐标grid_y以及网格x的坐标grid_x(和numpy中的输出顺序相反)
>>> grid_y, grid_x =  torch.meshgrid(
...         torch.arange(h),
...         torch.arange(w)
...     )
>>> grid_x
tensor([[0, 1, 2, 3],[0, 1, 2, 3]])
>>> grid_y
tensor([[0, 0, 0, 0],[1, 1, 1, 1]])

2、应用案例

2.1 利用np.meshgrid()来画决策边界

我们可以利用np.meshgrid()来画等高线图

# 等高线图
import numpy as np
import matplotlib.pyplot as plt# 模拟海拔高度
def fz(x, y):z = (1 -x / 2 + x**5 + y**3) * np.exp(-x**2-y**2)return zw = np.linspace(-4, 4, 100)
h = np.linspace(-2, 2, 100)grid_x, grid_y = np.meshgrid(w, h)
z = fz(grid_x, grid_y)plt.figure('Contour Chart',facecolor='lightgray')
plt.title('contour',fontsize=16)
plt.grid(linestyle=':')cntr = plt.contour(grid_x, # 网格坐标矩阵的x坐标(2维数组)grid_y, # 网格坐标矩阵的y坐标(2维数组)z,      # 网格坐标矩阵的z坐标(2维数组)8,      # 等高线绘制8部分colors = 'black', # 等高线图颜色linewidths = 0.5 # 等高线图线宽
)
# 设置标签
plt.clabel(cntr, inline_spacing = 1, fmt='%.2f', fontsize=10)
# 填充颜色  大的是红色  小的是蓝色
plt.contourf(grid_x, grid_y, z, 8, cmap='jet')plt.legend()
plt.show()

在这里插入图片描述

我们可以利用np.meshgrid()来画决策边界。

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as npfrom sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC# 使用sklearn自带的moon数据
X, y = make_moons(n_samples=100,noise=0.15,random_state=42)# 绘制生成的数据
def plot_dataset(X,y,axis):plt.plot(X[:,0][y == 0],X[:,1][y == 0],'bs')plt.plot(X[:,0][y == 1],X[:,1][y == 1],'go')plt.axis(axis)plt.grid(True,which='both')# 画出决策边界
def plot_pred(clf,axes):w = np.linspace(axes[0],axes[1], 100)h = np.linspace(axes[2],axes[3], 100)grid_x, grid_y = np.meshgrid(w, h)# grid_x 和 grid_y 被拉成一列,然后拼接成10000行2列的矩阵,表示所有点grid_xy = np.c_[grid_x.ravel(), grid_y.ravel()]# 二维点集才可以用来预测y_pred = clf.predict(grid_xy).reshape(grid_x.shape)# 等高线plt.contourf(grid_x, grid_y,y_pred,alpha=0.2)ploy_kernel_svm_clf = Pipeline(steps=[("scaler",StandardScaler()),("svm_clf",SVC(kernel='poly', degree=3, coef0=1, C=5))]
)ploy_kernel_svm_clf.fit(X,y)plot_pred(ploy_kernel_svm_clf,[-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

在这里插入图片描述

2.2 利用torch.meshgrid()生成网格所有坐标的矩阵

在目标检测YOLO中将图像划分为单元网格的部分就用到了torch.meshgrid()函数。

import torch
import numpy as npdef create_grid(input_size, stride=32):# 1、获取原始图像的w和hw, h = input_size, input_size# 2、获取经过32倍下采样后的feature mapws, hs = w // stride, h // stride# 3、生成网格的y坐标和x坐标grid_y , grid_x = torch.meshgrid([torch.arange(hs),torch.arange(ws)])# 4、将grid_x和grid_y进行拼接,拼接后的维度为【H, W, 2】grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()# 【H, W, 2】 -> 【HW, 2】grid_xy = grid_xy.view(-1, 2)return grid_xyif __name__ == '__main__':print(create_grid(input_size=32*4))
# 生成网格所有坐标的矩阵
tensor([[0., 0.],[1., 0.],[2., 0.],[3., 0.],[0., 1.],[1., 1.],[2., 1.],[3., 1.],[0., 2.],[1., 2.],[2., 2.],[3., 2.],[0., 3.],[1., 3.],[2., 3.],[3., 3.]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/293434.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nuxt3快速上手

1.安装&#xff1a; npx nuxi init project-name # project-name 是项目名,如果下载不下来请挂梯子。2.安装依赖&#xff1a; npm install3.运行项目&#xff1a; npm run dev4.代码解释&#xff1a; <template><!-- app.vue 是所有页面的入口&#xff1a; --&g…

力扣题目学习笔记(OC + Swift)19. 删除链表的倒数第 N 个结点

19. 删除链表的倒数第 N 个结点 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 此题目为链表题&#xff0c;拿出我们的杀手锏&#xff0c;链表解题经典三把斧&#xff1a; 哑巴节点栈快慢指针 关于内存问题&#xff1a;由于Swift及…

选择移动订货系统源码的四大原因

移动订货系统需要选择源码支持的厂家&#xff0c;有以下四个原因&#xff0c;其中第四个是比较重要的&#xff0c;大家点个关注点个赞&#xff0c;我们接着往下看。 1.可自行定制&#xff1a;支持源码的移动订货系统可以根据企业的具体需求进行定制开发&#xff0c;满足企业特定…

pyqt5实现wget下载视频文件的进度条显示

简介&#xff1a; 最近在写一个项目&#xff0c;用到了wget下载视频&#xff0c;为了更好的视觉效果&#xff0c;所以使用pyqt5中QProgressBar来实现下载进度条。当视频开始下载就会弹出下载进度条&#xff0c;下载完成后进度条消失。效果如下图; 具体代码实现 &#xff1a; …

FastGPT+ChatGLM3-6b搭建知识库

前言&#xff1a;我用fastgpt直接连接chatglm3&#xff0c;没有使用oneai&#xff0c;不是很复杂&#xff0c;只需要对chatglm3项目代码做少量修改就能支持使用embeddings&#xff0c;向量模型用的m3e&#xff0c;效果还可以 我的配置&#xff1a; 处理器&#xff1a;i5-13500 …

华为云Stack 8.X流量模型分析(三)

三、VPC内部二层流量模型分析 1.不同宿主机下虚拟机互访 VM1发送arp请求&#xff0c;arp报文根据流表到达br-tun&#xff0c;br-tun给予VM1到达VM2的MAC信息。此时arp报文不出宿主机&#xff08;Host1&#xff09;&#xff1b; **注意&#xff1a;**br-tun内的信息是由管理平…

Ignite分布式缓存框架

1.前言 Apache Ignite是一个分布式数据库&#xff0c;支持以内存级的速度进行高性能计算。 2。快速入门 本章节介绍运行Ignite的系统要求&#xff0c;如何安装&#xff0c;启动一个集群&#xff0c;然后运行一个简单的HelloWorld示例。 2.1.环境要求 Apache Ignite官方在如…

110基于matlab的混合方法组合的极限学习机和稀疏表示进行分类

基于matlab的混合方法组合的极限学习机和稀疏表示进行分类。通过将极限学习机&#xff08;ELM&#xff09;和稀疏表示&#xff08;SRC&#xff09;结合到统一框架中&#xff0c;混合分类器具有快速测试&#xff08;ELM的优点&#xff09;的优点&#xff0c;且显示出显着的分类精…

关于频谱仪是如何来实现辐射功率测量

1.1 内部基本原理框架 首先是接收到外部信号输入&#xff0c;然后经过可变衰减器衰减&#xff0c;接着进行变频&#xff0c;接着经过带宽带通滤波器进行滤波&#xff0c;滤波后的信号送入检波器进行信号检测&#xff0c;再经对数放大器放大后&#xff0c;送入低通滤波器进行视频…

java.lang.IllegalStateException: Duplicate key

序言 最近监控扫描出我们项目的某些异常信息&#xff0c;报错java.lang.IllegalStateException: Duplicate key xxx&#xff0c;看到异常来自stream流&#xff0c;然后定位看了一下是某位同事的代码使用stream流把List转Map集合出现重复的key异常信息。List集合A对象来源于某个…

鸿蒙ArkTS语言介绍与TS基础语法

1、ArkTS介绍 ArkTS是HarmonyOS主力应用开发语言&#xff0c;它在TS基础上&#xff0c;匹配ArkUI框架&#xff0c;扩展了声明式UI、状态管理等响应的能力&#xff0c;让开发者以更简洁、更自然的方式开发跨端应用。 JS 是一种属于网络的高级脚本语言&#xff0c;已经被广泛用…

FunBox11靶场 安装下载渗透详细教程

一. 下载靶场 官网下载地址 二. 安装 1.导入FunBox11 三.修改键盘布局和修改IP 参考历史文庄FunBox9安装教程 四. 打靶 1.提供arp-scan工具扫描网络主机IP arp-scan -l -i eh1 2.通过nmap 对目标主机进行扫描 nmap -A -p- -T5 172.30.1.134 -A : 启动Os检测&#xff0c;版…