原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列8

在这里插入图片描述

文章目录

  • 前言
  • 一、原始代码
  • 二、对每一行代码的解释:
  • 总结


前言

这是该系列原型网络的最后一段代码及其详细解释,感谢各位的阅读!


一、原始代码

if __name__ == '__main__':##载入数据labels_trainData, labels_testData = load_data()  # labels_trainData是字典,是key:value形式class_number_train = max(list(labels_trainData.keys())) #963class_number_test = max(list(labels_testData.keys())) #658wide = labels_trainData[0][0].shape[0]  # 105      #二维张量,shape[0]代表行数,shape[1]代表列数length = labels_trainData[0][0].shape[1]  # 105for label in labels_trainData.keys():labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])for label in labels_testData.keys():labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])##初始化模型protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)  # '''根据需求修改类的初始化参数,参数含义见protonets_net.py'''##训练prototypical_networkfor n in range(100):  ##随机选取x个类进行一个episode的训练protonets.train(labels_trainData, class_number_train)if n % 2 == 0 and n != 0:  # 每5次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')protonets.save_center('./log/model_center_' + str(n) + '.csv')test_accury = protonets.evaluation_model(labels_testData, class_number_test)print(test_accury)str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'with open('./log/model_step_eval.txt', "a") as f:f.write(str_data)print(n)

二、对每一行代码的解释:

  1. if __name__ == '__main__':
    这是一个Python的惯用写法,表示当脚本直接被运行时(而不是被作为模块导入时),才会执行下面的代码块。

  2. labels_trainData, labels_testData = load_data()
    调用 load_data() 函数加载数据,并将返回的标签训练数据和标签测试数据保存到 labels_trainDatalabels_testData 变量中。

  3. class_number_train = max(list(labels_trainData.keys()))
    获取标签训练数据中的最大键(即最大类别数),并将其保存到 class_number_train 变量中。

  4. class_number_test = max(list(labels_testData.keys()))
    获取标签测试数据中的最大键(即最大类别数),并将其保存到 class_number_test 变量中。

  5. wide = labels_trainData[0][0].shape[0]
    获取标签训练数据中第一个样本的宽度,并将其保存到 wide 变量中。

  6. length = labels_trainData[0][0].shape[1]
    获取标签训练数据中第一个样本的长度,并将其保存到 length 变量中。

  7. for label in labels_trainData.keys():
    遍历标签训练数据中的所有键。

  8. labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
    对每个标签训练数据进行重塑,将其形状改为 [-1, 1, wide, length],其中 -1 表示自动计算维度大小。

  9. for label in labels_testData.keys():
    遍历标签测试数据中的所有键。

  10. labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
    对每个标签测试数据进行重塑,将其形状改为 [-1, 1, wide, length]

  11. protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)
    创建一个 Protonets 类的实例,传入模型的初始化参数。

  12. for n in range(100):
    从0到99的循环中,执行以下代码块。

  13. protonets.train(labels_trainData, class_number_train)
    调用 protonets 实例的 train() 方法进行模型训练,传入标签训练数据和类别数。

  14. if n % 2 == 0 and n != 0:
    如果 n 是偶数且不为0,则执行以下代码块。

  15. torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
    保存模型到 './log/model_net_' + str(n) + '.pkl' 的文件路径。

  16. protonets.save_center('./log/model_center_' + str(n) + '.csv')
    调用 protonets 实例的 save_center() 方法,将模型的中心点保存到 './log/model_center_' + str(n) + '.csv'

  17. test_accury = protonets.evaluation_model(labels_testData, class_number_test)
    调用 protonets 实例的 evaluation_model() 方法,对模型进行评估并返回测试准确率,将其保存到 test_accury 变量中。

  18. print(test_accury)
    打印测试准确率。

  19. str_data = str(n) + ',' + str(' test_accury ') + str(test_accury) + '\n'
    构建一个字符串以保存到文件中。

  20. with open('./log/model_step_eval.txt', "a") as f:
    打开一个文件,以追加模式写入。


总结

原型网络(Prototypical Network)是一种用于小样本学习的模型,由Jake Snell等人于2017年提出。它是一种基于元学习(meta-learning)的方法,主要用于解决在具有少量标记样本的情况下进行分类任务的问题。

传统的深度学习模型在处理小样本学习时通常表现不佳,因为它们需要大量的标记样本来进行训练。然而,在现实世界中,我们往往只有少量标记样本可用。原型网络通过引入一个用于表示类别的中心向量(原型)的概念,解决了这个问题。

原型网络的功能和优势如下:

  1. 小样本学习:原型网络适用于具有少量标记样本的分类任务,可以在只有几个样本可用时进行准确的分类。

  2. 元学习能力:原型网络通过学习类别的原型向量,能够在遇到新类别时进行快速学习,从而实现元学习的目标。

  3. 欧氏距离度量:原型网络使用欧氏距离来度量样本与原型之间的相似性,从而进行分类推断。这种度量方式非常直观和可解释,使得模型更易于理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/193247.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

odoo17前端js框架的演化

odoo17发布了,从界面上看,变化还是很明显的,比16更漂亮了,本来以为源码不会发生太大的变化,结果仔细一瞧,变化也不小。 1、打包好的文件数量和大小发生了变化 打包好的文件从两个变成了一个,在…

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列7(承接系列6)

文章目录 前言一、原始代码---保存原型点,加载原型点二、代码逐行解释 前言 此部分为原型网络的两个函数,分别为保存原型点函数和加载原型点函数,与之前的系列相承接。 一、原始代码—保存原型点,加载原型点 def save_center(self,path):datas []for …

【算法挨揍日记】day29——139. 单词拆分、467. 环绕字符串中唯一的子字符串

139. 单词拆分 139. 单词拆分 题目描述: 给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 。 注意:不要求字典中出现的单词全部都使用,并且字典中的单词可以重复使用。 解题思路&am…

《2020年最新面经》—字节跳动Java社招面试题

文章目录 前言:一面:01、Java基础知识答疑,简单概述一下?02、倒排索引了解吗?使用Java语言怎么实现倒排?03、详细讲解一下redis里面的哈希表,常用的Redis哈希表命名有哪些,举例说明其…

科大讯飞会议笔记本、GoodNotes、E人E本 功能及体验对比

科大讯飞会议笔记本、GoodNotes、E人E本功能及体验对比 【旧文档,怕失传】 通过对科大讯飞会议笔记本、基于iPad的GoodNotes以及E人E本的各项功能指标进行了实际对比,得出了以下结果: 在实际体验中,科大讯飞笔记本在录音方面表…

C/C++ 获取主机网卡MAC地址

MAC地址(Media Access Control address),又称为物理地址或硬件地址,是网络适配器(网卡)在制造时被分配的全球唯一的48位地址。这个地址是数据链路层(OSI模型的第二层)的一部分&#…

STL的介绍

STL 是 C 标准模板库(Standard Template Library)的缩写,是 C 标准库中的一个重要组成部分。STL 提供了一组通用的模板类和函数,用于实现常用的数据结构和算法,如向量(vector)、链表&#xff08…

Alien Skin Exposure2024免费版图片颜色滤镜插件

Alien Skin Exposure一款非常专业的图片后期处理软件,内含500多种照片滤镜。是一款图片后期处理功能非常强大的软件。这款软件可以对图片的后期效果做很好的处理。 打开Alien Skin Exposure软件,会显示下面这个界面,如图1. ExposureX8win-安…

vue下载xlsx表格

vue下载xlsx表格 // 导入依赖库 import XLSX from xlsx; import FileSaver from file-saver; methods:{btn(){let date new Date()let Y date.getFullYear() -let M (date.getMonth() 1 < 10 ? 0 (date.getMonth() 1) : date.getMonth() 1) -let D (date.getDat…

vue引入前端工程内的图片

一、public目录下的图片 public目录下的图片引入方式&#xff1a; <!--/images/图片名称&#xff0c;这种属于绝对路径&#xff0c;/指向public目录 --> <img src"/images/image.png"> 二、src目录下的图片 先在vue.config.js进行配置&#xff0c;并指…

周年纪念篇

一周年纪念&#xff01; 凌晨逛手机版csdn时才突然发现已经错过一周年了&#xff0c;但我当闰年来纪念一下不过分吧hhh 浅浅的整些怀念的东西吧&#xff01; 这是人生第一段代码&#xff1a;不是hello world写不起&#xff0c;而是纯爱单推人更有性价比。 有这段代码在&#x…

移动端表格分页uni-app

使用uni-app提供的uni-table表格 网址&#xff1a;https://uniapp.dcloud.net.cn/component/uniui/uni-table.html#%E4%BB%8B%E7%BB%8D <uni-table ref"table" :loading"loading" border stripe type"selection" emptyText"暂无更多数据…