原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

在这里插入图片描述

文章目录

  • 原型网络进行分类的基本流程
  • 一、原始代码---计算欧氏距离,设计原型网络(计算原型+开始训练)
  • 二、每一行代码的详细解释
  • 总结


原型网络进行分类的基本流程

利用原型网络进行分类,基本流程如下:

1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)class Protonets(object):def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:#若训练一个新的模型,初始化CNN和中心点self.center = {}self.model = CNNnet(input_shape,outDim)else:#否则加载CNN模型和中心点self.center = {}self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点center = 0for i in range(self.Ns):data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])data = Variable(torch.from_numpy(data))data = self.model(data)[0]	#将查询点嵌入另一个空间if i == 0:center = dataelse:center += datacenter /= self.Nsreturn centerdef train(self,labels_data,class_number):	#网络的训练#Select class indices for episodeclass_index = list(range(class_number))random.shuffle(class_index)choss_class_index = class_index[:self.Nc]#选20个类sample = {'xc':[],'xq':[]}for label in choss_class_index:D_set = labels_data[label]#从D_set随机取支持集和查询集support_set,query_set = self.randomSample(D_set)#计算中心点self.center[label] = self.compute_center(support_set)#将中心和查询集存储在list中sample['xc'].append(self.center[label])	#listsample['xq'].append(query_set)#优化器optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)optimizer.zero_grad()protonets_loss = self.loss(sample)protonets_loss.backward()optimizer.step()

二、每一行代码的详细解释

def eucli_tensor(x, y):return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)

这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

class Protonets(object):def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:self.center = {}self.model = CNNnet(input_shape, outDim)else:self.center = {}self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')self.load_center(log_data + 'model_center_' + str(step) + '.csv')

这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

  • input_shape:输入数据的形状。
  • outDim:输出维度。
  • Ns:支持集(support set)的数量。
  • Nq:查询集(query set)的数量。
  • Nc:每次迭代所选类别数。
  • log_data:模型和中心的存储位置。
  • step:训练的步数。
  • trainval:是否重新开始训练模型。

根据 trainval 的取值,分为两种情况进行初始化:

  1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
  2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

总结

这段代码是一个用于实现 Protonets 算法的类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/190864.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端JS 使用input完成文件上传操作,并对文件进行类型转换

使用input实现文件上传 // 定义一个用于文件上传的按钮<input type"file" name"upload1" />// accept属性用于定义允许上传的文件类型&#xff0c; onchange用于绑定文件上传之后的相应函数<input type"file" name"upload2"…

cadence virtuoso寄生参数提取问题

问题描述&#xff1a; 寄生参数提取的最后一步出现问题 calibre View generation encountered a fatal Error.Please consult the logfile for messages. 解决办法&#xff1a; sudo gedit /etc/profile&#xff08;如果失败就切换到超级用户root&#xff0c;使用su root命令…

为什么选择B+树作为数据库索引结构?

背景 首先&#xff0c;来谈谈B树。为什么要使用B树&#xff1f;我们需要明白以下两个事实&#xff1a; 【事实1】 不同容量的存储器&#xff0c;访问速度差异悬殊。以磁盘和内存为例&#xff0c;访问磁盘的时间大概是ms级的&#xff0c;访问内存的时间大概是ns级的。有个形象…

【数据库】数据库连接池导致系统吞吐量上不去-复盘

在实际的开发中&#xff0c;我们会使用数据库连接池&#xff0c;但是如果不能很好的理解其中的含义&#xff0c;那么就可以出现生产事故。 HikariPool-1 - Connection is not available, request timed out after 30001ms.当系统的调用量上去&#xff0c;就出现大量这样的连接…

IIs部署发布vue项目测试环境

打开【控制面板 > 程序>启用或关闭Windows功能 】 1、安装IIS: 把这些勾选上&#xff0c;点击确定下载。 2、安装.net: 把这些勾选上&#xff0c;点击确定下载。 3、搜索IIs打开&#xff1a; 4、右击【网站>添加网站 】进行配置&#xff0c;点击确定。 4、右击[项目le…

zabbix告警 邮件告警 钉钉告警

邮件告警添加主机组添加模板添加主机在模板中添加监控项在模板中添加触发器添加动作&#xff0c;远程执行命令给用户绑定告警媒介类型 钉钉告警安装python依赖模块python-requests配置钉钉告警配置脚本zabbix_ding.conf在目录/var/log/zabbix中创建钉钉告警日志文件zabbix_ding…

数据结构与算法设计分析——常用搜索算法

目录 一、穷举搜索二、图的遍历算法&#xff08;一&#xff09;深度优先搜索&#xff08;DFS&#xff09;&#xff08;二&#xff09;广度优先搜索&#xff08;BFS&#xff09; 三、回溯法&#xff08;一&#xff09;回溯法的定义&#xff08;二&#xff09;回溯法的应用 四、分…

SpringBoot 2.x 实战仿B站高性能后端项目

SpringBoot 2.x 实战仿B站高性能后端项目 下栽の地止&#xff1a;请看文章末尾 通常SpringBoot新建项目&#xff0c;默认是集成了Maven&#xff0c;然后所有内容都在一个主模块中。 如果项目架构稍微复杂一点&#xff0c;就需要用到Maven多模块。 本文简单概述一下&#xff0c…

【论文阅读】(CTGAN)Modeling Tabular data using Conditional GAN

论文地址&#xff1a;[1907.00503] Modeling Tabular data using Conditional GAN (arxiv.org) 摘要 对表格数据中行的概率分布进行建模并生成真实的合成数据是一项非常重要的任务&#xff0c;有着许多挑战。本文设计了CTGAN&#xff0c;使用条件生成器解决挑战。为了帮助进行公…

如何去掉图片上的水印?这三种去水印的方法帮你解决!

当我们从网上看到喜欢的图片&#xff0c;想要保存下来作为头像或者插入到工作汇报中时&#xff0c;却发现下载的图片带有水印。这不仅影响了图片的美观&#xff0c;还可能对图片的可用性造成影响。那么&#xff0c;如何去掉图片上的水印呢? 实际上&#xff0c;现在市面上的很多…

【每日一题】数位和相等数对的最大和

文章目录 Tag题目来源题目解读解题思路方法一&#xff1a;哈希表 写在最后 Tag 【哈希表】【数组】【2023-11-18】 题目来源 2342. 数位和相等数对的最大和 题目解读 在数组中找出数位和相等数对的和的最大值。 解题思路 方法一&#xff1a;哈希表 维护一个不同的数位和表…

【Spring篇】使用注解进行开发

&#x1f38a;专栏【Spring】 &#x1f354;喜欢的诗句&#xff1a;更喜岷山千里雪 三军过后尽开颜。 &#x1f386;音乐分享【如愿】 &#x1f970;欢迎并且感谢大家指出小吉的问题 文章目录 &#x1f33a;原代码&#xff08;无注解&#xff09;&#x1f384;加上注解⭐两个注…