《动手学深度学习(PyTorch版)》笔记7.3

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.3 Network in Networks: NiN

在这里插入图片描述

NiN的想法是在每个像素位置(针对每个高度和宽度)应用一个全连接层。如果我们将权重连接到每个空间位置,我们可以将其视为 1 × 1 1\times 1 1×1卷积层,或作为在每个像素位置上独立作用的全连接层。从另一个角度看,即将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。NiN块以一个普通卷积层开始,后面是两个 1 × 1 1 \times 1 1×1的卷积层(充当带有ReLU激活函数的逐像素全连接层)。NiN和AlexNet之间的一个显著区别是NiN完全取消了全连接层而使用NiN块,其输出通道数等于标签类别的数量,且在最后用一个全局平均汇聚层(global average pooling layer),生成一个对数几率(logits)。NiN设计的一个优点是显著减少了模型所需参数的数量,然而在实践中,这种设计有时会增加训练模型的时间。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltdef nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10,设置输出通道为10使得输出通道数等于标签数nin_block(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),#全局平均汇聚后每个输出值都代表了对应通道整个输入的平均值,这导致网络更加关注图像中重要的特征,而不关注它们的位置nn.Flatten())#将四维的输出转成二维的输出,其形状为(批量大小,10)X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)#训练
lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/455058.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3自定义PostCss插件

Vue3自定义PostCss插件 插件功能: 实现自动转px为vw功能 1. 创建插件ts文件2. tsconfig.node.json引入插件3. vite.config.ts增加插件配置4. 编写插件内容5. 示例 插件功能: 实现自动转px为vw功能 px 固定单位,不会随着屏幕的变化而变化 vh vw 相对于视口高宽进行控制 1. 创建…

LeetCode541. 反转字符串 II

541. 反转字符串 II 给定一个字符串 s 和一个整数 k,从字符串开头算起,每计数至 2k 个字符,就反转这 2k 字符中的前 k 个字符。 如果剩余字符少于 k 个,则将剩余字符全部反转。如果剩余字符小于 2k 但大于或等于 k 个&#xff0…

[UI5 常用控件] 07.SplitApp,SplitContainer

文章目录 前言1. SplitApp1.1 组件结构1.2 Demo1.3 mode属性 2. SplitContainer 前言 本章节记录常用控件SplitApp,SplitContainer。主要功能是在左侧显示Master页面,右侧显示Detail页面。 Master页面和Detail页面可以由多个Page组成,并支持…

BEV感知算法学习

BEV感知算法学习 3D目标检测系列 Mono3D(Monocular 3D Object Detection for Autonomous Driving) 流程: 通过在地平面上假设先验,在3D空间中对具有典型物理尺寸的候选边界框进行采样;然后我们将这些方框投影到图像平面上,从而避…

SpringBoot实战第三天

今天主要完成了: 新增棋子分类 棋子分类列表 获取棋子分类详情 更新棋子分类 更新棋子分类和添加棋子分类_分组校验 新增棋子 新增棋子参数校验 棋子分类列表查询(条件分页) 先给出分类实体类 Data public class Category {private Integer id;//主键IDNot…

Spring- FactoryBean接口中的getObject()方法

目录 一、Spring框架介绍 二、FactoryBean接口是什么 三、getObject()方法如何使用 一、Spring框架介绍 Spring框架是一个轻量级的、非侵入式的Java企业级应用开发框架,以IoC(控制反转)和AOP(面向切面编程)为核心思…

阿里云服务器租用费用_2024年2月最新价格表

2024年2月阿里云服务器租用价格表更新,云服务器ECS经济型e实例2核2G、3M固定带宽99元一年、ECS u1实例2核4G、5M固定带宽、80G ESSD Entry盘优惠价格199元一年,轻量应用服务器2核2G3M带宽轻量服务器一年61元、2核4G4M带宽轻量服务器一年165元12个月、2核…

百面嵌入式专栏(技能篇)嵌入式技能树详解

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍嵌入式重点知识。 一、C语言 C语言这一块的高频考点有预处理、关键字、数据类型、指针与内存管理。 预处理有文件包含、宏定义、条件编译,其中最重要的是宏定义,通常考核宏定义的语法、宏替换与函数的区…

Redis核心技术与实战【学习笔记】 - 24.Redis 数据分片方案选择:Codis 和 Redis Cluster

简述 Redis 的切片集群使用多个实例保存数据,能很好的应对大数据量的场景。在《4.Redis 切片集群》中,介绍了 Redis 官方提供的切片集群方法 Redis Cluster。本章,再来学习下,在 Redis Cluster 方案正式发布前,业界广…

算法练习-四数之和(思路+流程图+代码)

难度参考 难度:中等 分类:数组 难度与分类由我所参与的培训课程提供,但需要注意的是,难度与分类仅供参考。且所在课程未提供测试平台,故实现代码主要为自行测试的那种,以下内容均为个人笔记,旨在…

推动海外云手机发展的几个因素

随着科技的不断发展,海外云手机作为一种新兴技术,在未来呈现出令人瞩目的发展趋势。本文将在用户需求、技术创新和全球市场前景等方面,探讨海外云手机在未来的发展。 1. 用户需求的引领: 随着人们对移动性和便捷性的需求不断增长&…