pytorch实现胶囊网络(capsulenet)

胶囊网络在hinton刚提出来的时候小热过一段时间,之后热度并没有维持多久。vision transformer之后基本少有人问津了。不过这个模型思路挺独特的,值得研究一下。

这个模型的提出是为了解决CNN模型学习到的特征之间没有空间上的关系,从而对于各种变换不鲁棒的缺点。

模型的整体思路如下:

1,胶囊:

抛开论文里花哨的描述,胶囊其实就是特征图上比点更大的单元,本质上我觉得类似transformer的patch。当然也有一定的差别,因为后续要用动态路由更新胶囊,所以胶囊必须要是向量,而不是标量。

2,动态路由:

由于pooling会导致信息丢失,作者使用动态路由来连接两个胶囊层,并更新胶囊。

同时,动态路由也能建立不同层胶囊(特征)在空间上的相对关系。

由于胶囊其实是向量,动态路由算法会根据这些向量的相似性(点积)和一致性(加权)来决定信息传递的路径。

3,整体结构:

1)卷积层

2)PrimaryCaps层:这层的作用就是把卷积特征转变成胶囊的形式

3)DigitCaps层:用动态路由迭代生成高层的胶囊。

4)解码器

4,loss

胶囊网络的损失函数主要由两部分组成:间隔损失(Margin Loss)和重构损失。

在计算间隔损失时,会使用一个阈值(通常设置为0.9和0.1)来区分正样本和负样本。如果某一类的胶囊输出向量的模长大于阈值m+(正样本阈值,例如0.9),则认为该类存在,并将其视为正样本;反之,如果输出向量的模长小于阈值m-(负样本阈值,例如0.1),则认为该类不存在,将其视为负样本。

重构损失的计算通常基于原始输入数据与重构数据之间的差异,例如使用均方误差(MSE)来衡量这种差异。

如果站在2024年的如今再来看当初的设计,其实胶囊的思路还是很像后来的transformer的,有点殊途同归的感觉。


pytorch实现:

1,实现初始胶囊

首先是会用到的压缩函数,压缩函数的作用是将向量的长度压缩到0和1之间,同时保留向量的方向不变。

公式:

def squash(inputs, axis=-1):norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)scale = norm**2 / (1 + norm**2 + 1e-8) / (norm + 1e-8)return scale * inputs

初始胶囊,这一层的作用是将卷积特征转换为胶囊的形式。

class PrimaryCapsule(nn.Module):def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):super(PrimaryCapsule, self).__init__()self.dim_caps = dim_capsself.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x):outputs = self.conv2d(x)outputs = outputs.reshape(x.size(0), -1, self.dim_caps)return squash(outputs)

2,实现胶囊层

路由算法

这个伪代码初看起来挺乱的,我翻译成人话如下:

首先,每一次迭代由两层胶囊层做点积后再通过softmax计算出耦合系数c。

耦合系数和下层胶囊的预测计算加权和,这是个投票的过程。

再通过压缩函数,就得到了本层的胶囊v。

因为这是个迭代的过程,需要不断更新耦合系数C。

新的耦合系数由两层胶囊之间的相似度决定。


具体实现中,会对低层胶囊先做一个变换,也就是下面代码里的weight。这个权重矩阵代表的是对下层胶囊的变化,变换之后的结果Ui|j用论文里的话说叫做“prediction vectors”。

胶囊层代码:

class DenseCapsule(nn.Module):def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):super(DenseCapsule, self).__init__()self.in_num_caps = in_num_capsself.in_dim_caps = in_dim_capsself.out_num_caps = out_num_capsself.out_dim_caps = out_dim_capsself.routings = routings #路由的迭代次数#初始化self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))def forward(self, x):u_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)#从当前计算图中分离出x_hat,这样在后续的反向传播中不会计算其梯度 u_hat_detached = u_hat.detach()b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps).cuda()#路由算法for i in range(self.routings):c = F.softmax(b, dim=1)if i == self.routings - 1:v = squash(torch.sum(c[:, :, :, None] * u_hat, dim=-2, keepdim=True))else:v = squash(torch.sum(c[:, :, :, None] * u_hat_detached, dim=-2, keepdim=True))b = b + torch.sum(v * u_hat_detached, dim=-1)return torch.squeeze(v, dim=-2)

需要将的是u_hat_detached = u_hat.detach()这一步。将u_hat从计算图中分离出来的目的,是为了防止迭代过程中梯度不断累积,导致梯度过大。所以我们可以在后续的路由算法中看出,只有在最后一次计算路由时使用了u_hat,之前的迭代中都是使用的u_hat_detached。从而让整个路由过程中梯度只更新一次。

3,损失函数

def caps_loss(y_true, y_pred, x, x_recon, lambd=0.5):L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2L_margin = L.sum(dim=1).mean()L_recon = nn.MSELoss()(x_recon, x)return L_margin + lambd * L_recon

4,整体模型

模型返回两个值,一个是预测的概率,一个是重建的图像。这两个值会分别用来计算间隔损失和重构损失。

class CapsuleNet(nn.Module):def __init__(self, input_size, classes, routings):super(CapsuleNet, self).__init__()self.input_size = input_sizeself.classes = classesself.routings = routingsself.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,out_num_caps=classes, out_dim_caps=16, routings=routings)self.decoder = nn.Sequential(nn.Linear(16*classes, 512),nn.ReLU(inplace=True),nn.Linear(512, 1024),nn.ReLU(inplace=True),nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),nn.Sigmoid())self.relu = nn.ReLU()def forward(self, x, y=None):x = self.relu(self.conv1(x))x = self.primarycaps(x)x = self.digitcaps(x)length = x.norm(dim=-1)if y is None:index = length.max(dim=1)[1]y = torch.zeros(length.size()).scatter_(1, index.view(-1, 1), 1.)reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))return length, reconstruction.view(-1, *self.input_size)

5,注意事项:

1)one-hot

在重建过程中使用的标签y是one-hot形式的,因此在训练和测试时需要加上这行代码,转换一下

targets = F.one_hot(targets, num_classes=classes).to(device)

2) loss

训练和测试时的loss设置如下

loss = caps_loss(y_true=targets,y_pred=y_pred,x=imgs,x_recon=x_recon,lambd=0.5)loss = loss.to(device)

其中lambd这个系数决定的是重构损失所占的比例 loss=margin_loss+lambd*recon_loss

总结:

胶囊网络分类结果不算差,在我的一些任务中train from scratch的胶囊网络就超越了imagenet1k上预训练过再finetune的vit。也超过了无预训练的VGG和resnet。(但是不如预训练过的vgg和resnet)。

这样的表现放在2017年已经很能打了,没火的原因我感觉有3个:

首先,由于胶囊网络迭代过程需要多次完整的特征图点乘特征图,所以内存消耗和时间消耗都是巨大的。我跑256的图时,24g显存的4090也只能把batch设置成5,运行速度非常慢。放在2017年,只能用1080ti来跑这个模型,简直折磨。(我2018年时也试过这个模型,训练都是按周算的,这谁愿意用啊)

另外一个原因可能是它的改进潜力不大。例如vit的核心机制是自注意力,注意力大家都玩出花来了,各种改进思路都很好借鉴。虽然vit效果很一般,但是后续的改进模型一个比一个厉害。而胶囊网络的核心路由算法想要创新就比较难。

最后还有一点就是原作者没放出胶囊网络在imagenet上的预训练模型。这个对模型热度的影响其实挺大的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/611053.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

政安晨:【Keras机器学习实践要点】(二十六)—— 内卷神经网络

目录 简介 设置 卷积 演变 测试逆卷积层 图像分类 获取 CIFAR10 数据集 数据可视化 卷积神经网络 逆向传播神经网络 比较 损失图和准确率图 可视化卷积核 结论 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: TensorFlow与Ke…

open c UF_MODL_create_simple_hole 识别放置平面 UF_MODL_ask_face_data

在BLOCK上创建一个简单孔 UF_FEATURE_SIGN sign UF_NULLSIGN;double block_orig[3] { -25.0,-25.0,0.0 };char* block_len[3] { "50","50","30" };tag_t blk_obj;UF_MODL_create_block1(sign, block_orig, block_len, &blk_obj);tag_t bo…

LabVIEW无线快速存取记录器(WQAR)测试平台

LabVIEW无线快速存取记录器(WQAR)测试平台 随着民用航空业的迅速发展,航空安全的保障日益成为公众和专业领域的关注焦点。无线快速存取记录器(WirelessQuick Access Recorder, WQAR)作为记录飞行数据、监控飞行品质的…

HarmonyOS NEXT应用开发—在Native侧实现进度通知功能

介绍 本示例通过模拟下载场景介绍如何将Native的进度信息实时同步到ArkTS侧。 效果图预览 使用说明 点击“Start Download“按钮后,Native侧启动子线程模拟下载任务Native侧启动子线程模拟下载,并通过Arkts的回调函数将进度信息实时传递到Arkts侧 实…

Redis入门到通过之Redis安装

文章目录 Redis安装说明1.单机安装Redis1.1.安装Redis依赖1.2.上传安装包并解压1.3.启动1.3.1.默认启动1.3.2.指定配置启动1.3.3.开机自启 2.Redis客户端2.1.Redis命令行客户端2.2.图形化桌面客户端2.2.1.安装2.2.2.建立连接 Redis安装说明 大多数企业都是基于Linux服务器来部…

详解小度Wi-Fi内部芯片及电路原理图分析

小度随身WiFi是一款便携式USB路由器,它实现了用户跨终端联网,随身携带,可以在室内实现免费WiFi覆盖。外形美观,小巧便携。 这一款小度WiFi采用的主芯片是MT7601UN,一款高度集成的Wi-Fi单芯片,支持150 Mbp…

HarmonyOS实战开发-图片编辑、使用 TextArea 实现多文本输入

介绍 本示例使用 TextArea 实现多文本输入,使用 ohos.app.ability.common 依赖系统的图库引用,实现在相册中获取图片,使用 ohos.multimedia.image 生成pixelMap,使用pixelMap的scale(),crop(),rotate()接口…

计算两个时间段的差值

计算两个时间段的差值 运行效果&#xff1a; 代码实现&#xff1a; #include<stdio.h>typedef struct {int h; // 时int m; // 分int s; // 秒 }Time;void fun(Time T[2], Time& diff) {int sum_s[2] { 0 }; for (int i 0; i < 1; i) { // 统一为秒数sum_s[…

pandas(day10)

一. 各各品类产品交易指数对比 获取文件名 files glob.glob("./*.xlsx")# 读取数据&#xff0c;并改列名&#xff0c;增加一列 品牌 dfs [] for f in files:t f[2:4]df pd.read_excel(f)df["品牌"] tif t "拜耳":df.rename(columns{"…

在线课程平台LearnDash评测 – 最佳 WordPress LMS插件

在我的LearnDash评测中&#xff0c;我探索了流行的 WordPress LMS 插件&#xff0c;该插件以其用户友好的拖放课程构建器而闻名。我深入研究了各种功能&#xff0c;包括课程创建、测验、作业、滴灌内容、焦点模式、报告、分析和管理工具。 我的评测还讨论了套餐和定价选项&…

视频基础学习五——视频编码基础二(编码参数帧、GOP、码率等)

系列文章目录 视频基础学习一——色立体、三原色以及像素 视频基础学习二——图像深度与格式&#xff08;RGB与YUV&#xff09; 视频基础学习三——视频帧率、码率与分辨率 视频基础学习四——视频编码基础一&#xff08;冗余信息&#xff09; 视频基础学习五——视频编码基础…

基于Springboot的笔记记录分享网站(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的笔记记录分享网站&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…