3.7.1 初始化模型参数

news/2025/3/10 15:54:21/文章来源:https://www.cnblogs.com/dingxingdi/p/18762869

nn.Linear不是可以自动展平吗?为什么还要添加nn.Flatten()?实际上,这两者的展平是不同的,前者的展平主要用在Seq2Seq里面,是最后一维不同,前两维合并,而后者的展平是第一维不同,后两维合并。具体用法如下
在 PyTorch 中,nn.Flatten() 是一个用于将张量(Tensor)展平为一维向量的层。它的主要作用是将多维的张量转换为适合全连接层(Fully Connected Layer)处理的一维形式。以下是其详细说明:


作用

  1. 展平张量

    • 将输入张量的除 batch 维度外的其他维度合并为一个维度。
    • 例如,输入形状为 (batch_size, C, H, W) 的图像张量,经过 Flatten() 后会变成 (batch_size, C*H*W)
  2. 简化模型定义

    • 在神经网络中,通常在卷积层(Convolutional Layer)之后需要将特征图(feature maps)展平为一维向量,以便输入到全连接层(Dense Layer)。Flatten() 提供了一个简洁的方式实现这一操作。

参数

nn.Flatten() 可以接受两个可选参数:

  • start_dim:从哪个维度开始展平(默认为 1,即从 batch 维度之后的第一个维度开始)。
  • end_dim:展平到哪个维度(默认为 -1,即展平到最后一个维度)。

示例参数说明

  • Flatten(start_dim=1, end_dim=-1):默认行为,展平所有维度(除 batch 维度外)。
  • Flatten(start_dim=2):从第 2 维(假设输入是 (B, C, H, W),则从 H 开始展平)。
  • Flatten(start_dim=1, end_dim=2):展平 CH 维度,保留 W 维度。

使用方法

1. 基本用法

import torch
import torch.nn as nn# 定义一个包含 Flatten 层的模型
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),  # 卷积层nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),  # 展平层nn.Linear(16 * 14 * 14, 10)  # 全连接层
)# 输入示例:假设输入图像形状为 (batch_size=1, channels=3, height=28, width=28)
x = torch.randn(1, 3, 28, 28)
output = model(x)
print(output.shape)  # 输出形状为 (1, 10)

2. 自定义展平范围

# 展平从第 2 维度开始到最后一个维度
flatten_layer = nn.Flatten(start_dim=2)
x = torch.randn(2, 3, 4, 5)  # 输入形状为 (2, 3, 4, 5)
y = flatten_layer(x)  # 输出形状为 (2, 3, 20)(4*5=20)

为什么需要 Flatten?

在神经网络中,常见的场景如下:

  1. 卷积层 → 全连接层

    • 卷积层的输出通常是 (batch_size, channels, height, width) 的 4D 张量。
    • 全连接层需要输入为 (batch_size, features) 的 2D 张量,因此需要展平。
  2. 避免手动计算维度

    • 手动计算展平后的维度(如 channels * height * width)容易出错,而 Flatten() 可自动处理。

Flatten 与 Reshape 的区别

  • Flatten

    • 是一个 PyTorch 层(Layer),直接嵌入在模型中。
    • 自动计算展平后的维度,无需手动指定目标形状。
    • 适用于模型定义中的动态展平。
  • reshape

    • 是张量的 方法(如 tensor.reshape(-1)),需要手动指定目标形状。
    • 需要明确知道展平后的维度,否则可能导致形状错误。
    • 不属于模型的一部分,通常用于数据预处理。

示例对比

# 使用 Flatten 层
x = torch.randn(1, 3, 28, 28)
model = nn.Sequential(nn.Flatten(),nn.Linear(3*28*28, 10)
)
output = model(x)  # 自动计算展平后的维度# 使用 reshape
x_flattened = x.reshape(x.shape[0], -1)  # 需要手动指定目标形状
linear = nn.Linear(3*28*28, 10)
output = linear(x_flattened)  # 需要手动计算维度

常见问题

  1. 输入已经是 2D,展平后会怎样?

    • 如果输入已经是 2D(如 (batch_size, features)),Flatten() 不会改变其形状。
  2. 如何处理动态输入形状?

    • Flatten() 可以自动处理不同 batch_size 或动态输入形状,无需手动调整。
  3. Flatten 是否影响梯度?

    • 不影响。展平操作是线性变换,梯度会正确反向传播。

总结

  • 作用:将多维张量展平为一维(保留 batch 维度)。
  • 适用场景:卷积层与全连接层之间,简化模型定义。
  • 参数:通过 start_dimend_dim 自定义展平范围。
  • 优势:自动处理维度计算,避免手动 reshape 的繁琐。

通过 nn.Flatten(),你可以更高效、简洁地构建复杂的神经网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/896850.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信创国产系统对国产芯片产业的推动作用

信创国产系统与国产芯片产业紧密相连,二者的协同发展对于我国信息技术产业的自主可控和安全稳定具有至关重要的意义。信创国产系统的崛起,正以一种前所未有的力量推动着国产芯片产业不断向前迈进,在提升产业竞争力、保障国家信息安全等方面发挥着不可忽视的作用。 创造市场需…

Apache DolphinScheduler项目2月份进展总结

各位热爱DolphinScheduler的小伙伴们,今年2月份的社区月报如期而至,更新了DolphinScheduler项目和社区在本月的重大进展,请查收! DolphinScheduler月度Merge Stars 感谢以下小伙伴在2025年2月期间为Apache DolphinScheduler社区做的精彩贡献(排名不分先后):@ruanwenjun,…

信创国产系统实施后的效果评估与改进方法

信创国产系统的实施是推动信息技术自主创新、保障国家信息安全的重要举措。随着信创国产系统在各个领域的广泛应用,对其实施后的效果进行科学评估并持续改进显得尤为关键。这不仅关系到系统能否稳定、高效运行,更关乎企业和国家在数字化转型过程中的战略布局与发展。通过合理…

3.10 lambda算法

1.1 表达式(expression)(可以把两个表达式写在一起组成一个新的表达式) 包含:变量(单个字母/多个字母);括号(表示是一个整体);λ和.描述函数(函数由λ和变量开头,然后是一个.,然后是表达式),λ没有特殊的含义,只是说函数由此开始,在λ后面,.前面的字母成为变…

pfastq-dump 软件的安装以及测试

pfastq-dump 软件的安装以及测试 001、官网:https://github.com/inutano/pfastq-dump002、下载最新版wget -c https://github.com/inutano/pfastq-dump/archive/refs/tags/v0.1.6.tar.gztar -xzvf pfastq-dump-0.1.6.tar.gzcd pfastq-dump-0.1.6/bin/chmod +x pfastq-dump 00…

7.9K star!跨平台开发从未如此简单,这个开源框架让APP开发效率飙升!

Lynx 是一个革命性的跨平台开发框架,使用 TypeScript 开发即可同时构建 iOS、Android 和 Web 应用。通过创新的布局引擎和原生渲染技术,让开发者用一套代码实现三端同屏效果,大大提升整体的开发效率!嗨,大家好,我是小华同学,关注我们获得“最新、最全、最优质”开源项目…

国内头部HR SaaS厂商的薪酬管理实践:以标准化功能满足复杂薪酬管理需求

易路的成功案例证明了其在薪酬数字化管理转型中的领导地位,为其他企业提供了宝贵的参考和启示。随着易路的不断创新和优化,我们有理由相信它将继续引领行业,帮助企业实现薪酬管理的战略性业务支撑,为企业在激烈的市场竞争中提供强大的人力资源支持,实现企业与员工的共同发…

源码安装Rpcapd,用于 wireshark 远程抓包

背景 libpcap 是一个基础且关键的网络数据包捕获库,为 Wireshark、tcpdump 等流行工具提供核心功能支持。其中,rpcapd(Remote Packet Capture Daemon)组件允许在远程系统上进行数据包捕获,这一功能让我们能够从一个中心位置监控多个远程网络接入点,而无需在每个监控点都部署…

3.10 计数基础排列与组合

1.1 基本计数原则:乘积法则 1.1.1总共有多少种不同的长度为7的位串(位串:可视为一个数组,长度为7) A:2^7=128 1.1.2 计数有穷集的子集|S|表示长度;幂集:幂集(Power Set)是集合论中的一个基本概念。给定一个集合 S,其幂集 P(S) 是包含 S 所有子集的集合,包括空集和 S…

Nginx 常用功能,反向代理笔记

前言 本文是runoob教程的搬运,稍微修改了原文中的一些错误拼写的问题,顺便对一些概念进行了更详细的解释,欢迎批评指正!Nginx常用功能Http代理,反向代理:作为web服务器最常用的功能之一,尤其是反向代理。 这里我给来2张图,对正向代理与反向代理做个诠释,具体细节,大家…

Oracle 19c 数据库实战:从单机部署到 DG 高可用架构搭建

前言:在当今数字化时代,数据已成为企业最宝贵的资产之一。而数据库作为数据存储和管理的核心工具,其重要性不言而喻。Oracle 数据库作为全球领先的商业数据库管理系统,以其卓越的性能、可靠性和强大的功能,广泛应用于企业的关键业务系统中。无论是大型企业的 ERP、CRM 系统…

002TypeScript开发实战

如果读取不到,情况下: 1、建好项目后我们在这里写一个ts语法,让项目跑起来npm run dev 2、在src中新建文件demo.vue