无脑入门pytorch系列(三)—— nn.Linear

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo1
  • demo2

官方定义

nn.Linear 是 PyTorch 中用于创建线性层的类。线性层也被称为全连接层,它将输入与权重矩阵相乘并加上偏置,然后通过激活函数进行非线性变换。

官方的文档如下,torch.nn.Linear:

image-20230811162434032

demo1

下面是一个官方文档给出的例子:

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())

输出的结果:

image-20230811162456557

首先,输出[128, 20]的张量,经过一个[20, 30]的线性层,变成[128, 30]的张量。
可以理解为矩阵的乘法,也就是矩阵的"外积",矩阵的叉乘,第一个矩阵的行数与第二个矩阵的列数相同。

demo2

input_data = torch.Tensor([[1, 2, 3], [4, 5, 6]]) # [2, 3] 
m = nn.Linear(3, 2)
output = m(input_data)
print(output) # [2, 2]

输出:

image-20230811162529522

可以看看nn.Linear(3, 2)的参数:

for param in m.parameters():print(param)

输出:

image-20230811162606663

结合参数,其实本身它们的计算就是矩阵的乘法:

image-20230811162640687

输入X为[n, i]的矩阵,经过W为[i,0]的矩阵,加上b的偏置得到Y为[n,o]的矩阵。

计算的思路也比较简单:
output[0][0] = [1, 2, 3] * [0.2888, -0.4596, -0,4896] + 0.3740 = -1.7253
output[0][1] = [1, 2, 3] * [0.4730, -0.4033, -0.4739] + 0.3182 = -1.4370
output[1][0] = [4, 5, 6] * [0.2888, -0.4596, -0,4896] + 0.3740 = -3.7066
output[1][1] = [4, 5, 6] * [0.4730, -0.4033, -0.4739] + 0.3182 = -2.6495

通过input和param的对比,我们可以很轻松地理解实际上就是矩阵的乘法操作。而模型在训练过程中就是不断调整param的参数使得输出的张量符合训练集的需求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/67927.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Vue前端】设置标题用于SEO优化

原文作者:我辈李想 版权声明:文章原创,转载时请务必加上原文超链接、作者信息和本声明。 文章目录 1.vue全局配置2.创建并暴露getPageTitle方法3.通过全局前置守卫设置title4.页面上引用title5.项目使用中英文翻译,title失效 1.vu…

K8S系列三:单服务部署

写在前面 本文是K8S系列第三篇,主要面向对K8S新手同学,阅读本文需要读者对K8S的基本概念,比如Pod、Deployment、Service、Namespace等基础概念有所了解。尚且不熟悉的同学推荐先阅读本系列的第一篇文章《K8S系列一:概念入门》[1]…

mysql窗口函数

MySQl 8.0 窗口函数 窗口函数适用场景: 对分组统计结果中的每一条记录进行计算的场景下, 使用窗口函数更好, 注意, 是每一条!! 因为MySQL的普通聚合函数的结果(如 group by)是每一组只有一条记录!!! 可以跟Hive的对比着看: 点我, 特么的花了一晚上整理, 没想到跟Hive 的基本一致…

ntfy Delphi 相关消息接口文档

关联文档: ntfy 实现消息订阅和通知(无需注册、无需服务器,太好了)_海纳老吴的博客-CSDN博客群晖 nas 自建 ntfy 通知服务(梦寐以求)_海纳老吴的博客-CSDN博客 目录 一、消息实体对象接口 1. 消息发布方…

【MybatisPlus】LambdaQueryWrapper和QueryWapper的区别

个人主页:金鳞踏雨 个人简介:大家好,我是金鳞,一个初出茅庐的Java小白 目前状况:22届普通本科毕业生,几经波折了,现在任职于一家国内大型知名日化公司,从事Java开发工作 我的博客&am…

自定义 视频/音频 进度条

复制代码根据自己需求改动就可以了 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><metaname"viewport"conten…

如何保证数据传输的安全?

要确保数据传输的安全&#xff0c;您可以采取以下措施&#xff1a; 使用加密协议&#xff1a;使用安全的传输协议&#xff0c;如HTTPS(HTTP over SSL/TLS)或其他安全协议&#xff0c;以保护数据在传输过程中的安全性。加密协议可以有效防止数据被窃听或篡改。 强化身份验证&…

途乐证券-越跌越买!港股又回调,资金却坚定买入,行情何时到来?

港股重复震动。 8月14日&#xff0c;港股三大指数再次跌落&#xff0c;截至收盘&#xff0c;恒生指数跌1.58%&#xff0c;报18773.55点&#xff0c;恒生科技指数跌1.52%&#xff0c;国企指数跌1.79%。 进入8月以来&#xff0c;港股继续跌落&#xff0c;恒生指数单月跌落6.5%&a…

Collada .dae模型格式简明教程

当你从互联网下载 3D 模型时&#xff0c;可能会在格式列表中看到 .dae 格式。 它是什么&#xff1f; 推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景。 1、Collada DAE概述 COLLADA是COLLAborative Design Activity&#xff08;中文&#xff1a;协作设计活动&#xff0…

vue3+vite配置vantUI主题

❓在项目中统一配置UI主题色&#xff0c;各个组件配色统一修改 vantUI按需安装 参考vantUI文档 创建vantVar.less文件夹进行样式编写 vantVar.less :root:root{//导航--van-nav-bar-height: 44px;//按钮--van-button-primary-color: #ffffff;--van-button-primary-backgr…

Python web实战之Django的AJAX支持详解

关键词&#xff1a;Web开发、Django、AJAX、前端交互、动态网页 今天和大家分享Django的AJAX支持。AJAX可实现在网页上动态加载内容、无刷新更新数据的需求。 1. AJAX简介 AJAX&#xff08;Asynchronous JavaScript and XML&#xff09;是一种在网页上实现异步通信的技术。通过…

Matlab图坐标轴数值负号改为减号(change the hyphen (-) into minus sign (−, “U+2212”))

在MATLAB中&#xff0c;坐标轴负数默认符号是 - &#xff0c;如下图所示 x 1:1:50; y sin(x); plot(x,y)可通过以下两语句将负号修改为减号&#xff1a; set(gca,defaultAxesTickLabelInterpreter,latex); yticklabels(strrep(yticklabels,-,$-$));或者 set(gca, TickLabe…