探索训练人工智能模型的词汇大小与模型的维度

news/2025/1/12 6:15:10/文章来源:https://www.cnblogs.com/jellyai/p/18550377

前一篇:《人工智能同样也会读死书----“过拟合”》

序言:你看,人工智能领域的专家都在做什么?他们其实只是在不断试错,因为并没有一种“万能药”——一种万能的算法可以一次性设计出任何人工智能大模型来实现客户的需求。所有的模型在设计和训练过程中都是——验证结构——修改架构——再验证新结构——再修改……最终达到设计的目的。

说难听一点,在当前的技术背景下,从事人工智能模型的设计就是动手实践的经验主义。这个行业需要的,就是您拥有足够的背景知识,能够快速理解并做出反应,以应对这个行业中各种概念的融合。这些概念大部分来自数学、统计学、人文学、社会学、语言学等多个宽泛的领域,也正是这个行业招募硕士、博士、博士后学历背景人才的原因,因为这个行业需要有足够的理论背景来支持。好了,让我们回到重点,本节知识实际上是人为地通过对训练数据集和验证数据集中词汇的相互关系以及维度的大小,迭代模型的架构,最终达到收敛,完成人工智能模型的设计。能理解今天的知识并独立付诸实践的朋友,您就是人工智能领域的专家。大家一起来测试测试吧。

探索词汇大小

Sarcasm 数据集处理的是单词,所以如果你去研究数据集中的单词,特别是它们的频率,你可能会找到一些可以帮助解决过拟合问题的线索。tokenizer 提供了一种方法,可以通过它的 word_counts 属性做到这一点。如果你打印出来,你会看到类似这样的结果,这是一个包含单词和词频元组的 OrderedDict:

wc = tokenizer.word_counts

print(wc)

OrderedDict([('former', 75), ('versace', 1), ('store', 35), ('clerk', 8),

('sues', 12), ('secret', 68), ('black', 203), ('code', 16),...])

这些单词的顺序是由它们在数据集中出现的顺序决定的。如果你查看训练集中的第一条标题,它是一条讽刺性的新闻,关于一位 Versace 前店员的故事。停用词(stopwords)已经被移除了,否则你会看到很多像 “a” 和 “the” 这样频率很高的词。

因为这是一个 OrderedDict,你可以按词频降序对它进行排序,代码如下:

from collections import OrderedDict

newlist = OrderedDict(sorted(wc.items(), key=lambda t: t[1], reverse=True))

print(newlist)

OrderedDict([('new', 1143), ('trump', 966), ('man', 940), ('not', 555),

('just', 430), ('will', 427), ('one', 406), ('year', 386),...])

如果你想把这些数据画成图表,你可以遍历列表中的每一项,把 x 轴的值设为单词在列表中的位置(第 1 项对应 x=1,第 2 项对应 x=2,以此类推),y 轴的值设为 newlist[item] 的词频。然后你可以用 Matplotlib 绘制图表,代码如下:

xs = []

ys = []

curr_x = 1

for item in newlist:

xs.append(curr_x)

curr_x += 1

ys.append(newlist[item])

plt.plot(xs, ys)

plt.show()

绘制出的结果如图 6-6 所示。

图 6-6:探索单词的频率

这条“冰球棒”曲线告诉我们,只有很少的单词被大量使用,而大多数单词只被使用了很少的次数。但实际上,每个单词在嵌入中都被赋予了相同的权重,因为每个单词在嵌入表中都有一个“条目”。由于训练集的规模相对验证集来说较大,我们会遇到一种情况:训练集中有许多单词,而这些单词并未出现在验证集中。

你可以通过修改图表的坐标轴来放大数据,在调用 plt.show 之前调整坐标轴。例如,如果你想查看 x 轴上单词 300 到 10,000 的范围,以及 y 轴上从 0 到 100 的频率范围,可以使用以下代码:

plt.plot(xs, ys)

plt.axis([300, 10000, 0, 100])

plt.show()

结果如图 6-7 所示。

图 6-7:单词 300–10,000 的频率

虽然语料库中包含了超过 20,000 个单词,但代码仅设置为训练 10,000 个单词。然而,如果我们查看位置 2,000 到 10,000 的单词(这部分占我们词汇表的 80% 以上),可以发现它们每个单词在整个语料库中出现的次数都少于 20 次。

这可能解释了过拟合的原因。现在想象一下,如果你将词汇表大小改为两千并重新训练,会发生什么?图 6-8 显示了准确率指标。此时训练集的准确率约为 82%,验证集的准确率约为 76%。两者之间的差距变小了,而且没有分离开,这表明我们已经解决了大部分的过拟合问题,这是一个好迹象。

图 6-8:使用两千词汇表时的准确率

这一点在图 6-9 的损失曲线中得到了进一步的验证。验证集的损失确实在上升,但速度比之前慢得多。所以,将词汇表大小减少,防止训练集过拟合那些可能只在训练集中出现的低频词,似乎起到了作用。

图 6-9:使用两千词汇表时的损失

值得尝试不同的词汇表大小,但要记住,词汇表也可能太小,导致模型过拟合到过少的词汇上。你需要找到一个平衡点。在这个案例中,我选择只保留出现 20 次或更多的单词,这只是一个完全随机的选择。

探索嵌入维度

在这个例子中,嵌入维度被随意地设置为 16。这意味着单词会被编码成 16 维空间中的向量,其方向表示它们的整体意义。但 16 是个合适的数字吗?在词汇表只有两千个单词的情况下,这个维度可能有点高了,导致方向上的高稀疏性。

嵌入大小的最佳实践是让它等于词汇表大小的四次方根。两千的四次方根是 6.687,因此我们可以尝试将嵌入维度改为 7,并重新训练模型 100 个训练周期。

你可以在图 6-9 中看到准确率的结果。训练集的准确率稳定在大约 83%,而验证集的准确率在大约 77%。尽管有些波动,但曲线总体上是平的,说明模型已经收敛了。虽然这和图 6-6 的结果没有太大区别,但减少嵌入维度让模型训练速度快了大约 30%

图 6-10:七维嵌入的训练与验证准确率对比

图 6-11 展示了训练和验证的损失曲线。虽然在大约第 20 个训练周期时损失看起来有所上升,但很快就趋于平稳了。这也是一个好信号!

图 6-11:七维嵌入的训练与验证损失对比

现在我们已经降低了嵌入维度,可以对模型架构进行进一步的调整了。

探索模型架构

在经过前面部分对模型进行优化之后,模型架构现在是这样的:

model = tf.keras.Sequential([

tf.keras.layers.Embedding(2000, 7),

tf.keras.layers.GlobalAveragePooling1D(),

tf.keras.layers.Dense(24, activation='relu'),

tf.keras.layers.Dense(1, activation='sigmoid')

])

model.compile(loss='binary_crossentropy',

optimizer='adam', metrics=['accuracy'])

第一眼看过去,模型的维度设计让人感到有点问题——GlobalAveragePooling1D 层现在只输出 7 个维度的数据,但这些数据被直接传入一个有 24 个神经元的全连接层,这显得有点过头了。那么,如果把这个全连接层的神经元数量减少到 8 个,并训练 100 个周期,会发生什么呢?

你可以在图 6-12 中看到训练和验证的准确率对比。和使用 24 个神经元的图 6-7 比较,整体结果差不多,但波动明显减少了(线条看起来不那么“锯齿状”了)。而且,训练速度也有所提升。

图 6-12:减少全连接层神经元后的准确率结果

同样,图 6-13 中的损失曲线显示了类似的结果,不过“锯齿感”也有所减弱。

图 6-13:减少全连接层神经元后的损失结果

总结:本篇所讲述的知识,对于人工智能领域的工程师或专家来说,会显得很基础;但对于普通大众而言,如果能够理解并亲手实践,完成文中知识所描述的目标,那您已经具备了人工智能领域专家的能力。再次强调,人工智能并不神秘,它本质上是利用广泛的知识教会机器模拟人类的思维方式和行为模式,用机器来为人类造福,同时利用机器所具备的人类不具备的优势,提升人类的生产力与文明。接下来,我们将为大家介绍一种常用的解决模型“读死书”问题的方法——“随机失活”(Dropout)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/835254.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北美竞赛-加拿大计算机竞赛CCC-收获滑铁卢

给定一个 RCRC 的方格矩阵。 矩阵左上角方格坐标为 (0,0)(0,0),右下角方格坐标为 (R−1,C−1)(R−1,C−1)。 每个方格中要么有南瓜,要么有干草。 南瓜分为大、中、小三种。 初始时,一个农民位于方格 (A,B)(A,B)。 他可以朝上下左右四个方向自由移动,但是他不能走出矩阵,也…

BUU CODE REVIEW 1 1

BUU CODE REVIEW 1 1 打开实例发现php代码,代码审计一波看到unserialize(),初步判断这题存在php反序列化 分析代码:需要GET传参传入pleaseget=1 需要POST传参传入pleasepost=2 需要POST传入md51和md52,使得md51的md5加密后的MD5值弱相等,参数值不相等 需要POST传入obj,用来…

25 个值得关注的检索增强生成 (RAG) 模型和框架

大型语言模型 (LLM) 如 GPT-4 彻底革新了自然语言处理 (NLP) 领域,在生成类人文本、回答问题和执行各种语言相关任务方面展现出卓越的能力。然而,这些模型也存在一些固有的局限性:知识截止:LLM 的训练数据通常截止于特定时间点,使其无法获取训练后发生的事件或信息。 静态…

IDEA不使用lombok,如何快速生成get和set方法

前言 大家好,我是小徐啊。我们在开发Java应用的时候,对于实体类,一般是entity或者pojo类,需要设置好属性的get和set方法。这是比较普通的操作。当然,现在已经有lombok这个插件和依赖来帮助我们不用写get和set方法了。不过,对于一些老系统,我还是习惯于手写get和set方法。…

爱玛单车队-冲刺日志第一天

会议记录:今天是整个冲刺计划最关键的一天,我们需要制定好整个计划并且安排好分工任务,为每个分工任务制定好负责人,来督促每个环节的任务。 本次冲刺确定了以下分工:成员姓名 职责曾庆徽 组长,分配协调组织林传昊 代码审查翁林靖 AI接回查找与测试毛震 软件测试(性能、…

达梦数据库数据类型的变更无效错误,如此解决妙啊

前言 大家好,我是小徐啊。之前在做国产化改造,用到了达梦数据库。其中的一项工作就是将旧数据库里面的数据和结构迁移到达梦数据库。达梦提供了迁移的的工具,大部分时间是挺好用的。 但是这里也有问题,比如我原来的数据库是postgresql,将它迁移到达梦数据库之后,在运行程…

域名选购操作指南

一、前言 在这个互联网时代, 域名已成为网站的数字身份证和品牌象征。它不仅是访问网站的便捷入口, 更一、前言 在这个互联网时代, 域名已成为网站的数字身份证和品牌象征。它不仅是访问网站的便捷入口, 更是树立网络品牌形象的重要资产。2024 年双十一期间, 我在腾讯云平台购置…

2024-2025-1 学号20241315《计算机基础与程序设计》第八周学习总结

作业信息这个作业属于哪个课程 2024-2025-1-计算机基础与程序设计这个作业要求在哪里 <作业要求的链接>https://www.cnblogs.com/rocedu/p/9577842.html#WEEK08这个作业目标 功能设计与面向对象设计 面向对象设计过程 面向对象语言三要素 汇编、编译、解释、执行作业正文…

2024六安市第二届网络安全大赛-misc

六安市第二届网络安全大赛复现misc听说你也喜欢俄罗斯方块?ppt拼接之后缺三个角补上flag{qfnh_wergh_wqef}流量分析流量包分离出来一个压缩包出来一张图片黑色代表0白色代表11010101000rab反的压缩包转一下密码:拾叁拾陆叁拾贰陆拾肆密文:4p4n575851324332474r324753574o594…

2024强网杯-misc

2024强网杯-misc谍影重重5.0打开发现是SMB流量,从NTLM流中找到数据来解密。用NTLMRawUnhide这个脚本 一键提取出数据。下载下来运行一下Hashcat直接爆破babygirl233再用smb流量脚本解密跑出key,再导入这个时候发现有flag的压缩包导出来压缩包需要密码,接着可以看到流量包还有…

2024网鼎杯青龙misc04

2024网鼎杯misc04Misc04首先看到一个杂乱的图片不过这是一个皮亚诺曲线上脚本from PIL import Imagefrom tqdm import tqdmdef peano(n): if n == 0: return [[0,0]] else: in_lst = peano(n - 1) lst = in_lst.copy() px,py = lst[-1] …