机器学习深度学习——NLP实战(自然语言推断——微调BERT实现)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——针对序列级和词元级应用微调BERT
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

NLP实战(自然语言推断——微调BERT实现)

  • 引入
  • 加载预训练的BERT
  • 微调BERT的数据集
  • 微调BERT
  • 小结

引入

在之前,已经为SNLI数据集上的自然语言推断任务设计了一个基于注意力的结构,文章链接:
机器学习&&深度学习——NLP实战(自然语言推断——注意力机制实现)
现在,我们通过微调BERT来重新审视这项任务。正如上一节讨论的那样,自然语言推断是一个序列级别的文本对分类问题,而微调BERT只需要一个额外的基于多层感知机的架构,如下图所示:
在这里插入图片描述
这边将下载一个已经预训练好的小版本BERT,然后对其进行微调,一遍在SNLI数据集上进行自然语言推断。

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

加载预训练的BERT

原始的BERT模型有数以亿计的参数。在下面,我们提供了两个版本的预训练BERT:“bert.base”与原始BERT基础模型一样大,需要大量计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip','225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip','c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,num_heads, num_layers, dropout, max_len, devices):data_dir = d2l.download_extract(pretrained_model)# 定义空词表以加载预定义词表vocab = d2l.Vocab()vocab.idx_to_token = json.load(open(os.path.join(data_dir,'vocab.json')))vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,num_heads=4, num_layers=2, dropout=0.2,max_len=max_len, key_size=256, query_size=256,value_size=256, hid_in_features=256,mlm_in_features=256, nsp_in_features=256)# 加载预训练BERT参数bert.load_state_dict(torch.load(os.path.join(data_dir,'pretrained.params')))return bert, vocab

为了便于在大多数机器上演示,我们将在本节中加载和微调经过预训练BERT的小版本(“bert.small”)。在练习中,我们将展示如何微调大得多的“bert.base”以显著提高测试精度。

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,num_layers=2, dropout=0.1, max_len=512, devices=devices)

微调BERT的数据集

对于SNLI数据集的下游任务自然语言推断,我们定义了一个定制的数据集类SNLIBERTDataset。在每个样本中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列。片段索引用于区分BERT输入序列中的前提和假设。利用预定义的BERT输入序列的最大长度(max_len),持续移除输入文本对中较长文本的最后一个标记,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,我们使用4个工作进程并行生成训练或测试样本。

class SNLIBERTDataset(torch.utils.data.Dataset):def __init__(self, dataset, max_len, vocab=None):all_premise_hypothesis_tokens = [[p_tokens, h_tokens] for p_tokens, h_tokens in zip(*[d2l.tokenize([s.lower() for s in sentences])for sentences in dataset[:2]])]self.labels = torch.tensor(dataset[2])self.vocab = vocabself.max_len = max_len(self.all_token_ids, self.all_segments,self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)print('read ' + str(len(self.all_token_ids)) + ' examples')def _preprocess(self, all_premise_hypothesis_tokens):pool = multiprocessing.Pool(4)  # 使用4个进程out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)all_token_ids = [token_ids for token_ids, segments, valid_len in out]all_segments = [segments for token_ids, segments, valid_len in out]valid_lens = [valid_len for token_ids, segments, valid_len in out]return (torch.tensor(all_token_ids, dtype=torch.long),torch.tensor(all_segments, dtype=torch.long),torch.tensor(valid_lens))def _mp_worker(self, premise_hypothesis_tokens):p_tokens, h_tokens = premise_hypothesis_tokensself._truncate_pair_of_tokens(p_tokens, h_tokens)tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \* (self.max_len - len(tokens))segments = segments + [0] * (self.max_len - len(segments))valid_len = len(tokens)return token_ids, segments, valid_lendef _truncate_pair_of_tokens(self, p_tokens, h_tokens):# 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置while len(p_tokens) + len(h_tokens) > self.max_len - 3:if len(p_tokens) > len(h_tokens):p_tokens.pop()else:h_tokens.pop()def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx]), self.labels[idx]def __len__(self):return len(self.all_token_ids)

读取完SNLI数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试样本。这些样本将在自然语言推断的训练和测试期间进行小批量读取。

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = "D:\Python\pytorch\data\snli_1.0\snli_1.0"
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,num_workers=num_workers)

微调BERT

用于自然语言推断的微调BERT只需要一个额外的多层感知机,该多层感知机由两个全连接层组成(下面代码的self.hidden和self.output)。这个多层感知机将特殊的“<cls>”词元的BERT表示进行了转换,该词元同时编码前提和假设的信息为自然语言推断的三个输出:蕴涵、矛盾和中性。

class BERTClassifier(nn.Module):def __init__(self, bert):super(BERTClassifier, self).__init__()self.encoder = bert.encoderself.hidden = bert.hiddenself.output = nn.Linear(256, 3)def forward(self, inputs):tokens_X, segments_X, valid_lens_x = inputsencoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)return self.output(self.hidden(encoded_X[:, 0, :]))

在下文中,预训练的BERT模型bert被送到用于下游应用的BERTClassifier实例net中。在BERT微调的常见实现中,只有额外的多层感知机(net.output)的输出层的参数将从零开始学习。预训练BERT编码器(net.encoder)和额外的多层感知机的隐藏层(net.hidden)的所有参数都将进行微调。

net = BERTClassifier(bert)

回想一下,在之前的文章:
机器学习&&深度学习——BERT(来自transformer的双向编码器表示)
其中,我们的MaskLM类和NextSentencePred类在其使用的多层感知机中都有一些参数。这些参数是预训练BERT模型bert中参数的一部分,因此是net中参数的一部分。然而,这些参数仅用于计算预训练过程中的遮蔽语言模型损失和下一句预测损失。这两个损失函数与微调下游应用无关,因此当BERT微调时,MaskLM和NextSentencePred中采用的多层感知机的参数不会更新(陈旧的,staled)。
为了允许具有陈旧梯度的参数,标志ignore_stale_grad=True在step函数d2l.train_batch_ch13中被设置。我们通过该函数使用SNLI的训练集(train_iter)和测试集(test_iter)对net模型进行训练和评估。

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

运行结果:

loss 0.520, train acc 0.790, test acc 0.779
446.5 examples/sec on [device(type=‘cpu’)]

运行图片:
在这里插入图片描述
如果计算资源允许,比如咱们去autodl平台上租借GPU以后,可以微调一个更大的预训练BERT模型,修改load_pretrained_model函数中的参数设置:将“bert.small”替换为“bert.base”,将num_hiddens=256、ffn_num_hiddens=512、num_heads=4和num_layers=2的值分别增加到768、3072、12和12。这样的测试精度应该是会高于0.86的。

小结

1、我们可以针对下游应用对预训练的BERT模型进行微调,例如在SNLI数据集上进行自然语言推断。
2、在微调过程中,BERT模型成为下游应用模型的一部分。仅与训练前损失相关的参数在微调期间不会更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/81646.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在 Opera 中启用DNS over HTTPS

DNS over HTTPS&#xff08;基于HTTPS的DNS&#xff09;是一种更安全的浏览方式&#xff0c;但大多数 Web 浏览器默认情况下不启用它。了解如何在 Opera 浏览器中启用该功能。 您可能不知道这一点&#xff0c;但您的网络浏览器并不像您希望的那样私密或安全。您会看到&#xff…

【BUG】解决安装oracle11g或12C中无法访问临时位置的问题

项目场景&#xff1a; 安装oracle时&#xff0c;到第二步出现oracle11g或12C中无法访问临时位置的问题。 解决方案&#xff1a; 针对客户端安装&#xff0c;在cmd中执行命令&#xff1a;前面加实际路径setup.exe -ignorePrereq -J"-Doracle.install.client.validate.cli…

CRYPTO 密码学-笔记

一、古典密码学 1.替换法&#xff1a;用固定的信息&#xff0c;将原文替换成密文 替换法的加密方式&#xff1a;一种是单表替换&#xff0c;另一种是多表替换 单表替换&#xff1a;原文和密文使用同一张表 abcde---》sfdgh 多表替换&#xff1a;有多涨表&#xff0c;原文和密文…

Python Opencv实践 - 图像直方图自适应均衡化

import cv2 as cv import numpy as np import matplotlib.pyplot as pltimg cv.imread("../SampleImages/cat.jpg", cv.IMREAD_GRAYSCALE) print(img.shape)#整幅图像做普通的直方图均衡化 img_hist_equalized cv.equalizeHist(img)#图像直方图自适应均衡化 #1. 创…

四信5G工业路由器赋能5G LAN全连接工厂建设

5G作为“新基建”之首&#xff0c;肩负着驱动国民经济转型升级、促进实体经济与数字经济深度融合、满足各行各业高质量通信服务需求的重任。 随着5G技术的更新迭代&#xff0c;各行各业对网络的可靠性&#xff0c;确定性等提出更高的需求&#xff0c;5G LAN作为3GPP R16标准定…

3、Spring_容器执行

容器执行点 1.整合 druid 连接池 添加依赖 <dependency><groupId>com.alibaba</groupId><artifactId>druid</artifactId><version>1.2.8</version> </dependency>1.硬编码方式整合 新建德鲁伊配置 <?xml version"1.…

【网络安全】跨站脚本(xss)攻击

跨站点脚本&#xff08;也称为 XSS&#xff09;是一种 Web 安全漏洞&#xff0c;允许攻击者破坏用户与易受攻击的应用程序的交互。它允许攻击者绕过同源策略&#xff0c;该策略旨在将不同的网站彼此隔离。跨站点脚本漏洞通常允许攻击者伪装成受害者用户&#xff0c;执行用户能够…

Linux C 多进程编程(面试考点)

嵌入式开发为什么要移植操作系统&#xff1f; 1.减小软硬件的耦合度&#xff0c;提高软件的移植性 2. 操作系统提供很多库和工具&#xff08;QT Open CV&#xff09;&#xff0c;提高开发效率 3.操作系统提供多任务机制&#xff0c;______________________? (提高C…

Linux 网络编程 和 字节序的概念

网络编程概述 不同于之前学习的所有通讯方法&#xff0c;多基于Linux内核实现&#xff0c;只能在同一个系统中不同进程或线程间通讯&#xff0c;Linux的网络编程可以实现真正的多机通讯&#xff01; 两个不相关的终端要实现通讯&#xff0c;必须依赖网络&#xff0c;通过地址…

C运行时错误——error realloc(): invalid next size

在LeetCode做题时遇到一个运行时错误&#xff0c;将引起问题的原因记录一下备忘&#xff1a; 我们在malloc或calloc等API分配内存时&#xff0c;libc库除了分配给我们在参数中设定大小的内存&#xff08;可能会有内存对齐&#xff0c;实际分配的比参数设定的要多&#xff09;&…

免费清理电脑:删除垃圾文件以提升电脑性能

求助&#xff01;电脑上没有可用空间 ​“我只在电脑上存储了大约100张照片&#xff0c;为什么我的硬盘空间已满&#xff1f;电脑运行速度也变得越来越慢&#xff0c;要疯了&#xff01;现在我想安装更新的驱动程序。我可以释放磁盘空间吗&#xff1f;有免费的Windows电脑清…

【React源码实现】元素渲染的实现原理

前言 本文将结合React的设计思想来实现元素的渲染&#xff0c;即通过JSX语法的方式是如何创建为真实dom渲染到页面上&#xff0c;本文基本不涉及React的源码&#xff0c;但与React的实现思路是一致的&#xff0c;所以非常适合小白学习&#xff0c;建议跟着步骤敲代码&#xff…