Pytorch入门—Tensors张量的学习

Tensors张量的学习

张量是一种特殊的数据结构,与数组和矩阵非常相似。在PyTorch中,我们使用张量来编码模型的输入和输出,以及模型的参数。

张量类似于NumPy的ndarrays,只是张量可以在GPU或其他硬件加速器上运行。事实上,张量和NumPy数组通常可以共享相同的底层内存,从而无需复制数据(请参阅使用NumPy进行桥接)。张量还针对自动微分进行了优化(我们将在稍后的Autograd部分中看到更多内容)。如果您熟悉ndarrays,您将熟悉Tensor API。

import torch
import numpy as np

Initializing a Tensor 初始化张量

Directly from data 直接从数据中初始化

张量可以直接从数据中创建。数据类型是自动推断的。

data = [[1, 2],[3, 4]]
x_data = torch.tensor(data)

image-20240507094522422

From a NumPy array 从NumPy数组初始化

张量可以从NumPy数组中创建(反之亦然—请参阅使用NumPy进行桥接)。

np_array = np.array(data)
x_np = torch.from_numpy(np_array)

From another tensor 从另一个tensor初始化

新张量保留参数张量的属性(形状,数据类型),除非显式覆盖。

x_ones = torch.ones_like(x_data) # retains the properties of x_data
print(f"Ones Tensor: \n {x_ones} \n")x_rand = torch.rand_like(x_data, dtype=torch.float) # overrides the datatype of x_data
print(f"Random Tensor: \n {x_rand} \n")

image-20240507095106372

With random or constant values
具有随机值或常量值

shape 是张量维度的元组。在下面的函数中,它确定输出张量的维数。

shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")

image-20240507095334820

Attributes of a Tensor 张量的属性

张量属性描述了它们的形状、数据类型以及存储它们的设备。

tensor = torch.rand(3,4)print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")

image-20240507095546591

Standard numpy-like indexing and slicing
标准的numpy式索引和切片

tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")
tensor[:,1] = 0
print(tensor)

image-20240507100001132

Joining tensors 连接张量

连接张量您可以使用 torch.cat 将一系列张量沿着给定的维度连接起来。另请参见torch.stack,这是另一个与 torch.cat 略有不同的张量连接运算符。

t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)

image-20240507100440770

Arithmetic operations 算术运算

# This computes the matrix multiplication between two tensors. y1, y2, y3 will have the same value
# ``tensor.T`` returns the transpose of a tensor
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)# This computes the element-wise product. z1, z2, z3 will have the same value
z1 = tensor * tensor
z2 = tensor.mul(tensor)z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

这段代码主要演示了如何在PyTorch中进行矩阵乘法和元素级乘法。

  1. 矩阵乘法:

    y1 = tensor @ tensor.Ty2 = tensor.matmul(tensor.T) 这两行代码都在进行矩阵乘法。@操作符和matmul函数都可以用于矩阵乘法。tensor.T返回tensor的转置。

    y3 = torch.rand_like(y1) 创建了一个与y1形状相同,元素为随机数的新tensor。

    torch.matmul(tensor, tensor.T, out=y3) 这行代码也在进行矩阵乘法,但是结果被直接写入了y3,而不是创建新的tensor。

  2. 元素级乘法:

    z1 = tensor * tensorz2 = tensor.mul(tensor) 这两行代码都在进行元素级乘法。*操作符和mul函数都可以用于元素级乘法。

    z3 = torch.rand_like(tensor) 创建了一个与tensor形状相同,元素为随机数的新tensor。

    torch.mul(tensor, tensor, out=z3) 这行代码也在进行元素级乘法,但是结果被直接写入了z3,而不是创建新的tensor。

矩阵乘法与元素级乘法是什么?

矩阵乘法和元素级乘法是两种不同的数学运算。

  1. 矩阵乘法:也被称为点积,是一种二元运算,将两个矩阵相乘以产生第三个矩阵。假设我们有两个矩阵A和B,A的形状是(m, n),B的形状是(n, p),那么我们可以进行矩阵乘法得到一个新的矩阵C,其形状是(m, p)。C中的每个元素是通过将A的行向量和B的列向量对应元素相乘然后求和得到的。
  2. 元素级乘法:也被称为Hadamard积,是一种二元运算,将两个矩阵相乘以产生第三个矩阵。假设我们有两个形状相同的矩阵A和B,那么我们可以进行元素级乘法得到一个新的矩阵C,其形状与A和B相同。C中的每个元素是通过将A和B中对应位置的元素相乘得到的。

在Python的NumPy和PyTorch库中,你可以使用@matmul函数进行矩阵乘法,使用*mul函数进行元素级乘法。

Single-element tensors

单元素张量

如果你有一个单元素张量,例如通过将张量的所有值聚合为一个值,你可以使用 item() 将它转换为Python数值。

agg = tensor.sum()
agg_item = agg.item()
print(agg_item, type(agg_item))

image-20240507102052385

In-place operations

就地操作

将结果存储到操作数中的操作称为就地操作。它们由 _ 后缀表示。例如: x.copy_(y)x.t_() ,将更改 x

print(f"{tensor} \n")
tensor.add_(5)
print(tensor)

image-20240507102216996

NOTE 注意
就地操作保存一些内存,但是在计算导数时可能会出现问题,因为会立即丢失历史。因此,不鼓励使用它们。

Bridge with NumPy

CPU和NumPy数组上的张量可以共享它们的底层内存位置,改变一个就会改变另一个。

张量到NumPy数组

t = torch.ones(5)
print(f"t: {t}")
n = t.numpy()
print(f"n: {n}")

image-20240507102621371

张量的变化反映在NumPy数组中。

t.add_(1)
print(f"t: {t}")
print(f"n: {n}")

image-20240507102720944

NumPy数组到张量

n = np.ones(5)
t = torch.from_numpy(n)

NumPy数组中的变化反映在张量中。

np.add(n, 1, out=n)
print(f"t: {t}")
print(f"n: {n}")

image-20240507102955148

Notebook来源:

[Tensors - PyTorch Tuesday 2.3.0+ cu 121文档 — Tensors — PyTorch Tutorials 2.3.0+cu121 documentation](

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/681814.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

音转文工具,9.8k star! 【送源码】

我们经常会遇到将音频转为文字的情况,比如在开会时录音的会议纪要、上课时录下的老师讲课内容。虽然网上也有一些在线的工具可以将音频转为文字,但是考虑到数据安全和费用问题,使用起来也不是很方便。 今天了不起给大家介绍一款开源工具——…

#友元函数与友元类

目录 1.概念 2.友元函数 3.友元类 1.概念 友元提供了一种突破封装的方式,有时提供了便利。但是友元会增加耦合度,破坏了封装,所以友元不宜多 用。 友元分为:友元函数和友元类 2.友元函数 友元函数可以直接访问类的私有成员&a…

「MDN web 入门」学习笔记

目录 写在前面 1. MDN 简介 1.1 MDN 的主要特点 1.2 MDN 的主要功能 1.3 MDN 网页开发的指南 2. 安装基础软件 2.1 专业人士工具 2.2 初学者基本工具 3. 设计网站外观 3.1 计划 3.2 绘制草图 3.3 选定素材 3.4 文本 3.5 主题颜色 3.6 图像 3.7 字体 4. 处理文…

六西格玛项目的核心要素:理论学习、实践应用与项目经验

许多朋友担心,没有项目经验是否就意味着无法考取六西格玛证书。针对这一疑问,张驰咨询为大家详细解答。 首先,需要明确的是,六西格玛项目不仅仅是一种管理工具或方法,更是一种追求卓越、持续改进的思维方式。它强调通…

5.07 Pneumonia Detection in Chest X-Rays using Neural Networks

肺炎诊断是一个耗时的过程,需要高技能的专业人员分析胸部X光片chest X-ray (CXR),并通过临床病史、生命体征和实验室检查确认诊断。 它可以帮助医生确定肺部感染的程度和位置。呼吸道疾病在 X 光片上表现为一处膨胀的不透明区域。然而,由于不…

科技云报道:从亚运到奥运,大型国际赛事共赴“云端”

科技云报道原创。 “广播电视转播技术拯救了奥运会”前奥委会主席萨马兰奇这句话广为流传。 奥运会、世界杯、亚运会这样的全球大型体育赛事不仅是体育竞技的盛宴,也是商业盛宴,还是技术与人文的融合秀。随着科技的进步,技术在体育赛事中扮…

2.外卖点餐系统(Java项目 springboot)

目录 0.系统的受众说明 1.系统功能设计 2.系统结构设计 3.数据库设计 3.1实体ER图 3.2数据表 4.系统实现 4.1用户功能模块 4.2管理员功能模块 4.3商家功能模块 4.4用户前台功能模块 4.5骑手功能模块 5.相关说明 新鲜运行起来的项目:如需要源码数据库…

树形数据结构---堆

1.概念 什么是堆?堆和树的区别是什么?它的应用场景有哪些? 堆(Heap)是一种基于树形结构的数据结构,它是一种特殊的完全二叉树。堆的特点是每个节点都满足堆的性质,即父节点的键值总是大于或等…

8.删除有序数组中的重复项 II

文章目录 题目简介题目解答解法一:双指针(快慢指针)代码:复杂度分析: 题目链接 大家好,我是晓星航。今天为大家带来的是 删除有序数组中的重复项 II 相关的讲解!😀 题目简介 题目解…

[NSSRound#1 Basic]sql_by_sql

[NSSRound#1 Basic]sql_by_sql 这题没啥难的&#xff0c;二次注入盲注的套题 先注册&#xff0c;进去有个修改密码 可能是二次注入 修改密码处源码 <!-- update user set password%s where username%s; -->重新注册一个admin-- 获得admin身份&#xff08;原理看sqli-l…

【日常开发之FTP】Windows开启FTP、Java实现FTP文件上传下载

【日常开发之FTP】windows开启FTP、Java实现FTP文件上传下载 FTP前言FTP是什么&#xff1f;FTP两种模式 Windows开启FTPFTP windows 配置防火墙配置 Java部分Maven配置创建FTPClient 注意 FTP前言 FTP是什么&#xff1f; FTP是一个专门进行文件管理的操作服务&#xff0c;一般…

从github上复制代码,记录失败过程,最后放弃挣扎直接去github上手动下载了

从github中复制代码出现一下两种报错 1、Failed to connect to github.com port 443: 连接超时 2、Failed to connect to github.com port 443: 拒绝连接 解决办法如下&#xff1a; https://www.ipaddress.com/ github.com的IP地址为140.82.114.3 终端打开hosts文件加入140…